diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 35891307e..761e48e59 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,3 +1,4 @@ from .pipeline import PPInferEngine -__all__ = ["PPInferEngine"] + +__all__ = ['PPInferEngine'] diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md index a90d5d6da..f9bb35cc4 100644 --- a/colossalai/inference/pipeline/README.md +++ b/colossalai/inference/pipeline/README.md @@ -17,7 +17,7 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). 1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: - - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`. + - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`. - Run the pipeline inference model. 2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: @@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag ### Example ```python -from colossalai.pipeline import PPInferEngine -# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example. -model = LlamaForCausalLM.from_pretrained('/path/to/model') -inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt") -engine = PPInferEngine( - pp_size=2, - dtype='fp16', - micro_batch_size=1, - new_length=10, - model=model, - model_policy=LlamaForCausalLMPipelinePolicy()) +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy +import colossalai +from transformers import LlamaForCausalLM, LlamaTokenizer -output = engine.inference([inputs]) +colossalai.launch_from_torch(config={}) -``` +model = LlamaForCausalLM.from_pretrained("/path/to/model") +tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") -### Quick start -```shell -cd benchmark -sh run.sh +# assume the model is inferred with 2 pipeline stages +inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32) + +input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] +data = tokenizer(input, return_tensors='pt') +output = inferengine.inference(data.to('cuda')) +print(tokenizer.batch_decode(output)) ``` ## Performance -We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G. +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. 
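The tables below were collected with the bundled `benchmark/run.sh` script. As a rough illustration of how a single throughput figure can be reproduced, here is a minimal sketch that reuses the `inferengine`, `tokenizer` and `data` objects from the example above and counts throughput as generated tokens per second (the exact accounting in `benchmark.py` may differ):

```python
import time

# `inferengine` and `data` are assumed to be set up as in the example above;
# new_length must match the value passed to PPInferEngine.
new_length = 32
batch_size = data["input_ids"].shape[0]

start = time.perf_counter()
output = inferengine.inference(data.to("cuda"))
elapsed = time.perf_counter() - start

# each sequence in the batch generates `new_length` new tokens
print(f"throughput: {batch_size * new_length / elapsed:.2f} tokens/s")
```

As with the example above, the script has to be launched with one process per pipeline stage (for example via `colossalai run --nproc_per_node 2`, as done in `benchmark/run.sh`).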
-### Llama Throughput(tokens/s) +### Llama Throughput (tokens/s) | input length=1024, output length=128 -#### 7b, fp16 +#### A10 7b, fp16 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| | :---: | :---: | :---: | :---: | :---: | :---: | :---:| -| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM | -| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | -| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 | -| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM | +| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | -#### 7b, fp32 +#### A10 13b, fp16 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 | -| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM | -| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 | -| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM | +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | -#### 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 | -| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM | -| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 | -| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM | + +#### A800 7b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | + + +#### A800 13b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 | +| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 9c47909f7..8392d0a1e 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -7,7 +7,7 @@ import transformers import colossalai from colossalai.inference import PPInferEngine -from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -117,8 +117,11 @@ if __name__ == "__main__": micro_batch_size=args.mb_size, new_length=args.new_length, model=model, - model_policy=LlamaForCausalLMPipelinePolicy(), + model_policy=LlamaModelInferPolicy(), verbose=True, + max_batch_size=args.mb_size, + max_input_len=args.seq_len, + max_output_len=args.seq_len + args.new_length + 256, ) data = data_gen(args.batch_size, args.seq_len) diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh index 7d8da8586..e3c33bb88 100644 --- a/colossalai/inference/pipeline/benchmark/run.sh +++ b/colossalai/inference/pipeline/benchmark/run.sh @@ -1,7 +1,7 @@ script_dir=$(cd "$(dirname "$0")" && pwd) cd "${script_dir}" -# 7b, fp32, 2 gpu, 1024, 128 +# 7b, fp16, 2 gpu, 1024, 128 for BATCH_SIZE in 2 4 8 16; do CUDA_VISIBLE_DEVICES=0,1 colossalai run 
--nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="7b" \ @@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do --pp_size=2 done -# 7b, fp32, 2 gpu, 512, 512 +# 7b, fp16, 2 gpu, 512, 512 for BATCH_SIZE in 2 4 8 16 32; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="7b" \ @@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do --pp_size=2 done -# 7b, fp32, 2 gpu, 1024, 128 +# 7b, fp16, 2 gpu, 1024, 128 for BATCH_SIZE in 2 4 8; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="13b" \ diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 4f42385ca..480ac5dc7 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from transformers.tokenization_utils_base import BatchEncoding from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.schedule.generate import GenerateSchedule @@ -7,6 +8,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy +from ..tensor_parallel.kvcache_manager import MemoryManager from .microbatch_manager import MicroBatchManager @@ -23,20 +25,29 @@ class PPInferEngine: micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. new_length (int): the new length of the input sequence. early_stopping (bool): whether to stop early. + max_batch_size (int): the maximum batch size. + max_input_len (int): the maximum input length. + max_output_len (int): the maximum output length. 
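+        Note: `max_batch_size`, `max_input_len` and `max_output_len` size the pre-allocated KV cache:
+            each micro-batch buffer slot reserves room for `max_batch_size * (max_input_len + max_output_len)`
+            tokens, and `max_output_len` is raised to at least `max_input_len + new_length` internally.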
Example: ```python - from colossalai.ppinference import PPInferEngine - from transformers import GPT2LMHeadModel, GPT2Tokenizer + from colossalai.inference import PPInferEngine + from colossalai.inference.pipeline.policies import LlamaModelInferPolicy + import colossalai + from transformers import LlamaForCausalLM, LlamaTokenizer - model = transformers.GPT2LMHeadModel.from_pretrained('gpt2') - # assume the model is infered with 4 pipeline stages - inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding}) + colossalai.launch_from_torch(config={}) + + model = LlamaForCausalLM.from_pretrained("your_path_to_model") + tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") + # assume the model is infered with 2 pipeline stages + inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) + + input = ["Introduce a landmark in China ","Introduce a landmark in China "] + data = tokenizer(input, return_tensors='pt') + output = inferengine.inference([data.to('cuda').data]) - input = ["Hello, my dog is cute, and I like"] - tokenized_input = tokenizer(input, return_tensors='pt') - output = engine.inference([tokenized_input]) ``` """ @@ -51,6 +62,9 @@ class PPInferEngine: new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, + max_batch_size: int = 4, + max_input_len: int = 32, + max_output_len: int = 32, verbose: bool = False, # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, @@ -58,24 +72,53 @@ class PPInferEngine: num_beams: int = 1, ) -> None: assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + + max_output_len = max(max_output_len, max_input_len + new_length) + self.pp_size = pp_size + if dtype == "fp16": + self.dtype = torch.float16 + model.half() + elif dtype == "bf16": + self.dtype = torch.bfloat16 + model.to(torch.bfloat16) + else: + self.dtype = torch.float32 self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) + self.model = pp_model or self._shardformer(model, model_policy) + self.cache_manager_list = [ + self._init_manager(max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] self.mb_manager = MicroBatchManager( - self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size + self.stage_manager.stage, + new_length, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, ) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) - assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - if dtype == "fp16": - model.half() - elif dtype == "bf16": - model.to(torch.bfloat16) - self.model = pp_model or self._shardformer(model, model_policy) - def inference(self, input_list): - out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) + """ + Args: + input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. + + Returns: + out (list): a list of output data, each element is a list of token. + timestamp (float): the time cost of the inference, only return when verbose is `True`. 
+ """ + assert isinstance( + input_list, (BatchEncoding, dict) + ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." + if isinstance(input_list, BatchEncoding): + input_list = input_list.data + out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) if self.verbose: return out, timestamp else: @@ -95,3 +138,17 @@ class PPInferEngine: shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) return shard_model.cuda() + + def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + max_total_token_num = max_batch_size * (max_input_len + max_output_len) + head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + head_num = self.model.config.num_attention_heads + num_hidden_layers = ( + self.model.config.num_hidden_layers + if hasattr(self.model.config, "num_hidden_layers") + else self.model.config.num_layers + ) + layer_num = num_hidden_layers // self.pp_size + + cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num) + return cache_manager diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 49d1bf3f4..2bf52161d 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -1,8 +1,11 @@ from enum import Enum -from typing import Dict, Tuple +from typing import Dict import torch +from ..tensor_parallel.batch_infer_state import BatchInferState +from ..tensor_parallel.kvcache_manager import MemoryManager + __all__ = "MicroBatchManager" @@ -27,21 +30,20 @@ class MicroBatchDescription: def __init__( self, inputs_dict: Dict[str, torch.Tensor], - output_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, new_length: int, ) -> None: - assert output_dict.get("hidden_states") is not None - self.mb_length = output_dict["hidden_states"].shape[-2] + self.mb_length = inputs_dict["input_ids"].shape[-1] self.target_length = self.mb_length + new_length - self.kv_cache = () + self.infer_state = BatchInferState.init_from_batch( + batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager + ) + # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}") - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - if output_dict is not None: - self._update_kvcache(output_dict["past_key_values"]) - - def _update_kvcache(self, kv_cache: Tuple): - assert type(kv_cache) == tuple - self.kv_cache = kv_cache + def update(self, *args, **kwargs): + pass @property def state(self): @@ -80,17 +82,21 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ def __init__( - self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int, ) -> None: - super().__init__(inputs_dict, output_dict, new_length) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) assert inputs_dict is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None self.input_ids = inputs_dict["input_ids"] self.attn_mask = inputs_dict["attention_mask"] self.new_tokens 
= None - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - super().update(output_dict, new_token) + def update(self, new_token: torch.Tensor = None): if new_token is not None: self._update_newtokens(new_token) if self.state is not Status.DONE and new_token is not None: @@ -125,16 +131,17 @@ class BodyMicroBatchDescription(MicroBatchDescription): Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. - output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. """ def __init__( - self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int, ) -> None: - super().__init__(inputs_dict, output_dict, new_length) - - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - super().update(output_dict, new_token) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) @property def cur_length(self): @@ -142,10 +149,7 @@ class BodyMicroBatchDescription(MicroBatchDescription): When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 """ - if len(self.kv_cache) == 0: - return self.mb_length - else: - return self.kv_cache[0][0].shape[-2] + 1 + return self.infer_state.seq_len.max().item() class MicroBatchManager: @@ -160,16 +164,38 @@ class MicroBatchManager: """ - def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + def __init__( + self, + stage: int, + new_length: int, + micro_batch_size: int, + micro_batch_buffer_size: int, + max_input_len: int, + max_output_len: int, + cache_manager_list: MemoryManager, + ): self.stage = stage self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.cache_manager_list = cache_manager_list self.mb_descrption_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + ) + else: + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + ) + + def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. @@ -181,11 +207,7 @@ class MicroBatchManager: new_token (torch.Tensor): the new token generated by current stage. 
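+            Returns:
+                Status: the state of the current micro batch after this update.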
""" # Add descrption first if the descrption is None - if inputs_dict is None and output_dict is None and new_token is None: - return Status.PREFILL - if self.mb_descrption_buffer.get(self.idx) is None: - self._add_descrption(inputs_dict, output_dict) - self.cur_descrption.update(output_dict, new_token) + self.cur_descrption.update(new_token) return self.cur_state def export_new_tokens(self): @@ -204,16 +226,12 @@ class MicroBatchManager: def clear(self): self.mb_descrption_buffer.clear() + for cache in self.cache_manager_list: + cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]): - if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length) - else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length) - def _remove_descrption(self): self.mb_descrption_buffer.pop(self.idx) @@ -222,10 +240,10 @@ class MicroBatchManager: return self.mb_descrption_buffer.get(self.idx) @property - def cur_kv_cache(self): + def cur_infer_state(self): if self.cur_descrption is None: return None - return self.cur_descrption.kv_cache + return self.cur_descrption.infer_state @property def cur_state(self): diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py index e69de29bb..239bdebd7 100644 --- a/colossalai/inference/pipeline/modeling/__init__.py +++ b/colossalai/inference/pipeline/modeling/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaInferenceForwards + +__all__ = ["LlamaInferenceForwards"] diff --git a/colossalai/inference/pipeline/modeling/_utils.py b/colossalai/inference/pipeline/modeling/_utils.py new file mode 100644 index 000000000..068b64b4f --- /dev/null +++ b/colossalai/inference/pipeline/modeling/_utils.py @@ -0,0 +1,67 @@ +""" +Utils for model inference +""" +import os + +import torch + +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor 
+ base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py deleted file mode 100644 index d2bfcb8b6..000000000 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ /dev/null @@ -1,280 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - - -class GPT2PipelineForwards: - """ - This class serves as a micro library for forward function substitution of GPT2 models - under pipeline setting. - """ - - @staticmethod - def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. - # Please refer to original code of transformers for more details. 
- logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - else: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # Going through held blocks. 
- start_idx, end_idx = stage_index[0], stage_index[1] - for i, layer_past in zip(range(start_idx, end_idx), past_key_values): - block = self.h[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - if stage_manager.is_last_stage(): - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - return {"hidden_states": hidden_states, "past_key_values": presents} - - @staticmethod - def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. 
Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # If is first stage and after warmup, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} - - # Not first stage or before warmup, go through gpt2 model - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - - return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index f46e1fbdd..9c72b02cc 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,158 +1,72 @@ -from typing import List, Optional +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +from typing import List, Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) from transformers.utils import logging +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager +from ._utils import copy_kv_to_mem_cache -class LlamaPipelineForwards: +try: + from vllm import layernorm_ops, pos_encoding_ops + + rms_norm = layernorm_ops.rms_norm + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print( + "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + ) + HAS_VLLM_KERNERL = False + +try: + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, 
sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaInferenceForwards: """ - This class serves as a micro library for forward function substitution of Llama models - under pipeline setting. + This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. """ - def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): - logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), 
dtype=torch.bool, device=hidden_states.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - start_idx, end_idx = stage_index[0], stage_index[1] - if past_key_values is None: - past_key_values = tuple([None] * (end_idx - start_idx + 1)) - - for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): - decoder_layer = self.layers[idx] - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - - # always return dict for imediate stage - return {"hidden_states": hidden_states, "past_key_values": next_cache} - - def llama_for_causal_lm_forward( + @staticmethod + def llama_causal_lm_forward( self: LlamaForCausalLM, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -164,6 +78,7 @@ class LlamaPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -175,24 +90,7 @@ class LlamaPipelineForwards: config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" 
- >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + """ logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -210,7 +108,7 @@ class LlamaPipelineForwards: return {"logits": lm_logits} # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = LlamaPipelineForwards.llama_model_forward( + outputs = LlamaInferenceForwards.llama_model_forward( self.model, input_ids=input_ids, attention_mask=attention_mask, @@ -221,9 +119,348 @@ class LlamaPipelineForwards: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + infer_state=infer_state, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, ) return outputs + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + # batch_size = input_ids.shape[0] # input_ids.shape[0] + # print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}") + + # infer_state = self.infer_state + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager is None or stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + assert stage_manager is not None + assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}" + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if infer_state.is_context_stage is False: + past_key_values_length = infer_state.cache_manager.past_key_values_length + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assume prefill stage + # allocate memory 
block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + new_shape = [1] * position_ids.dim() + new_shape[0] = batch_size + position_ids = position_ids.repeat(*new_shape).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + # decoder layers + () if output_hidden_states else None + () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_state.decode_layer_id = 0 + + start_idx, end_idx = stage_index[0], stage_index[1] + if past_key_values is None: + past_key_values = tuple([None] * (end_idx - start_idx + 1)) + + for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): + decoder_layer = self.layers[idx] + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if 
output_attentions else 1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + # TODO: fix this to necessary return + # if not return_dict: + # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}") + return {"hidden_states": hidden_states, "past_key_values": next_cache} + + @staticmethod + def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # NOTE might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # 
once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + cos, sin = infer_state.position_cos, infer_state.position_sin + + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") + # first token generation + + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + # print(f"rank:{torch.distributed.get_rank()}, {attn_output}") + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None + + +def get_llama_vllm_rmsnorm_forward(): + if HAS_VLLM_KERNERL: + + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None diff --git a/colossalai/inference/pipeline/policies/__init__.py b/colossalai/inference/pipeline/policies/__init__.py new file mode 100644 index 000000000..7271812c5 --- /dev/null +++ b/colossalai/inference/pipeline/policies/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaModelInferPolicy + +__all__ = ["LlamaModelInferPolicy"] diff --git a/colossalai/inference/pipeline/policies/llama.py b/colossalai/inference/pipeline/policies/llama.py new file mode 100644 index 000000000..9f8c93c61 --- /dev/null +++ 
b/colossalai/inference/pipeline/policies/llama.py @@ -0,0 +1,145 @@ +from functools import partial +from typing import List + +import torch +from torch.nn import Module +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) + +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward + +try: + from colossalai.kernel.triton import rmsnorm_forward + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + self.shard_config._infer() + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + if 
self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy + ) + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + else: + # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 + infer_forward = get_llama_vllm_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py deleted file mode 100644 index 51e6425b1..000000000 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ /dev/null @@ -1,74 +0,0 @@ -from functools import partial -from typing import Callable, Dict, List - -from torch import Tensor, nn - -import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from colossalai.shardformer.policies.gpt2 import GPT2Policy - -from ..modeling.gpt2 import GPT2PipelineForwards - - -class GPT2LMHeadModelPipelinePolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel - - module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} - ) - ] - ) - } - module_policy.update(addon_module) - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy, - ) - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - # make the tie weight lm_head and embedding in the same device to save memory - # if self.pipeline_stage_manager.is_first_stage(): - if self.pipeline_stage_manager.is_first_stage(): - held_layers.append(self.model.lm_head) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """The weights of wte and lm_head are shared.""" - module = self.model - stage_manager = self.pipeline_stage_manager - if stage_manager is not None: - if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] - return [] - - def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if not 
self.pipeline_stage_manager: - raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "GPT2Model": - module = self.model - else: - module = self.model.transformer - - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py deleted file mode 100644 index 6e12ed61b..000000000 --- a/colossalai/inference/pipeline/policy/llama_ppinfer.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import List - -from torch.nn import Module - -from colossalai.shardformer.layer import Linear1D_Col -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription -from colossalai.shardformer.policies.llama import LlamaPolicy - -from ..modeling.llama import LlamaPipelineForwards - - -class LlamaForCausalLMPipelinePolicy(LlamaPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers import LlamaForCausalLM - - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm - new_item = { - LlamaForCausalLM: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ) - ] - ) - } - policy.update(new_item) - - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy - ) - - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_first_stage(): - held_layers.append(self.model.lm_head) - return held_layers diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py deleted file mode 100644 index c26aa4e40..000000000 --- a/colossalai/inference/pipeline/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Set - -import torch.nn as nn - -from colossalai.shardformer._utils import getattr_, setattr_ - - -def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None: - """ - Set all parameters and buffers of model to None - - Args: - model (nn.Module): The model to set - """ - for module_suffix in include: - set_module = getattr_(model, module_suffix) - for n, p in set_module.named_parameters(): - setattr_(set_module, n, None) - for n, buf in set_module.named_buffers(): - setattr_(set_module, n, None) - setattr_(model, module_suffix, None) - - -def get_suffix_name(suffix: str, name: str): - """ - Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit, - and 'name' when `suffix` is empty. - - Args: - suffix (str): The suffix of the suffix module - name (str): The name of the current module - """ - point = "" if suffix is "" else "." 
- suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" - return suffix_name diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index de150311c..f707a86df 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -2,9 +2,11 @@ from dataclasses import dataclass import torch +from transformers.tokenization_utils_base import BatchEncoding from .kvcache_manager import MemoryManager + # adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: @@ -55,3 +57,62 @@ class BatchInferState: ] start_index += cur_seq_len return + + @classmethod + def init_from_batch( + cls, + batch: torch.Tensor, + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ): + if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(batch, (BatchEncoding, dict)): + input_ids_list = batch["input_ids"] + attention_mask = batch["attention_mask"] + else: + input_ids_list = batch + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(batch, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") + + return cls( + batch_size=batch_size, + max_len_in_batch=max_len_in_batch, + seq_len=seq_lengths.to("cuda"), + start_loc=seq_start_indexes.to("cuda"), + block_loc=block_loc, + decode_layer_id=0, + past_key_values_len=0, + is_context_stage=True, + cache_manager=cache_manager, + ) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 1f4bbe9f8..db02dab59 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -93,9 +93,9 @@ class GenerateSchedule(PipelineSchedule): Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` """ - model_inputs = ( - {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None - ) + model_inputs = { + 'infer_state': self.mb_manager.cur_descrption.infer_state + } return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): @@ -108,9 +108,8 @@ class GenerateSchedule(PipelineSchedule): dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': 
torch.Tensor, 'past_key_values': torch.Tensor}` """ new_mask = self.mb_manager.cur_descrption.attn_mask -        past_key_values = self.mb_manager.cur_descrption.kv_cache -        return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values) +        return dict(input_ids=new_token, attention_mask=new_mask) def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: last_hidden_state = hidden_state[:, -1] @@ -128,27 +127,38 @@ return self.comm.p2p_recv() return self.comm.recv_forward() -    def _load_stage_action(self, model: Module) -> None: +    def _init_infer_state_action(self) -> None: """ -        In this action, 1.load micro_batch 2.do the forward 3.step to update +        This action is only for non-first stages: it loads the micro-batch and initializes its infer_state. +        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state """ inputs_dict = self.load_micro_batch() +        self.mb_manager.add_descrption(inputs_dict) + +    def _load_stage_action(self, model: Module) -> None: +        """ +        This action is only for the first stage: it loads the micro-batch, initializes its infer_state, and runs the forward pass. +        1.load micro_batch 2.do the forward 3.step to update +        """ +        inputs_dict = self.load_micro_batch() +        self.mb_manager.add_descrption(inputs_dict) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) -        output_dict = model_forward(model, inputs_dict, None) +        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} +        output_dict = model_forward(model, inputs_dict, interval_inputs) -        self.mb_manager.step(inputs_dict, output_dict, None) -        self.action_interval_buffer.hidden_states = output_dict["hidden_states"] +        self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _gen_token_action(self, model: Module): """ -        In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update +        This action is only for the first stage: +        1.do the forward with hidden_states to generate new tokens 2.step to update """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" -        hidden_states = {"hidden_states": hidden_states} -        logits = model_forward(model, None, hidden_states) +        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} +        logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) @@ -157,7 +167,7 @@ class GenerateSchedule(PipelineSchedule): ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits["logits"]) -        self.mb_manager.step(None, None, new_token) +        self.mb_manager.step(new_token) self.action_interval_buffer.new_token = new_token self.action_interval_buffer.hidden_states = None @@ -168,20 +178,18 @@ new_token = self.action_interval_buffer.new_token assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" inputs_dict = self._prepare_inputs_for_new_token(new_token) -        output_dict = model_forward(model, inputs_dict, None) +        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} +        output_dict = model_forward(model, inputs_dict, interval_inputs) -        self.mb_manager.step(inputs_dict, output_dict, None) -        self.action_interval_buffer.hidden_states = output_dict["hidden_states"] +        self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" -        inputs_dict = self._prepare_inputs_for_interval_stage() -        hidden_states = {"hidden_states": hidden_states} -        output_dict = model_forward(model, inputs_dict, hidden_states) +        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} +        output_dict = model_forward(model, None, interval_inputs) -        self.mb_manager.step(inputs_dict, output_dict, None) -        self.action_interval_buffer.hidden_states = output_dict["hidden_states"] +        self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -218,6 +226,8 @@ class GenerateSchedule(PipelineSchedule): actions.append(partial(self._gen_token_action, model)) # other stage else: +                if self.mb_manager.cur_state is Status.PREFILL: +                    actions.append(partial(self._init_infer_state_action)) actions.append(partial(self._comm_action, True)) actions.append(partial(self._body_encoding_action, model)) @@ -308,8 +318,9 @@ class GenerateSchedule(PipelineSchedule): if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) -                    output_dict = model_forward(model, inputs_dict, None) -                    self.mb_manager.step(inputs_dict, output_dict, None) +                    self.mb_manager.add_descrption(inputs_dict) +                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} +                    output_dict = model_forward(model, inputs_dict, interval_inputs) # In GENERATE phase else: # Get hidden_states from previous stage @@ -319,25 +330,28 @@ class GenerateSchedule(PipelineSchedule): assert ( hidden_states is not None ), "When first stage in GENERATE phase, the hidden states should not be None" -                        logits = model_forward(model, None, hidden_states) +                        interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} +                        logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) -                        assert ( -                            "logits" in logits -                        ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" -                        new_token = self._get_token_id(logits["logits"]) -                        self.mb_manager.step(None, None, new_token) +                        assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}" +                        new_token = self._get_token_id(logits['logits']) +                        self.mb_manager.step(new_token) # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) -                            output_dict = model_forward(model, inputs_dict, None) -                            self.mb_manager.step(inputs_dict, output_dict, None) +                            interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} +                            output_dict = model_forward(model, inputs_dict, interval_inputs) else: assert hidden_states is not None, "When not first stage, the hidden states should not be None" -                        inputs_dict = self._prepare_inputs_for_interval_stage() -                        output_dict = model_forward(model, inputs_dict, hidden_states) -
self.mb_manager.step(inputs_dict, output_dict, None) +                        # non-first stages only load a micro-batch at PREFILL; afterwards they only need hidden_states and infer_state +                        inputs_dict = None +                        if self.mb_manager.cur_state is Status.PREFILL: +                            inputs_dict = self.load_micro_batch() +                            self.mb_manager.add_descrption(inputs_dict) +                        interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} +                        output_dict = model_forward(model, inputs_dict, interval_inputs) # Current microbatch is not DONE, send hidden_state to next stage if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a285874d2..2aa613983 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -76,4 +76,5 @@ class ShardConfig: """ Set default params for inference. """ -        assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" +        # pipeline parallelism is now supported in inference, so the former assertion on pipeline_stage_manager is removed +        pass diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index ad8e32b48..6d02f2b32 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -2,12 +2,15 @@ import pytest import torch import torch.distributed as dist import transformers +from packaging import version import colossalai -from colossalai.inference.pipeline.engine import PPInferEngine -from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from colossalai.inference.pipeline import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + def data_gen(): input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) @@ -24,20 +27,21 @@ for k, v in inputs.items(): def pipeline_inference_test(pp_size, new_length, micro_batch_size): -    model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) +    model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) + engine = PPInferEngine( pp_size=pp_size, model=model, -        model_policy=GPT2LMHeadModelPipelinePolicy(), +        model_policy=LlamaModelInferPolicy(), new_length=new_length, micro_batch_size=micro_batch_size, ) -    output = engine.inference([inputs]) +    output = engine.inference(inputs) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize("pp_size", [4]) +@parameterize("pp_size", [2]) @parameterize("new_length", [4, 8, 16]) @parameterize("micro_batch_size", [1, 4]) @clear_cache_before_run() @@ -51,11 +55,12 @@ def check_pipeline_inference(rank, world_size, port): run_pipeline_inference_test() +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_pipeline_inference(): -    spawn(check_pipeline_inference, nprocs=4) +    spawn(check_pipeline_inference, nprocs=2) if __name__ == "__main__":