From 27e62ba0f7a11cf79c765575bead6ccb564964be Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 22 Nov 2023 13:53:08 +0800 Subject: [PATCH] [inference] decouple pp logic for llama (#5092) * [example] update inference benchmark * [inference] decouple pp logic for llama * [inference] update examples --- colossalai/inference/engine/engine.py | 110 +++++++++++------- colossalai/inference/engine/modeling/llama.py | 67 +++++++---- .../inference/kv_cache/kvcache_manager.py | 3 +- examples/inference/benchmark_llama.py | 47 ++++++-- examples/inference/example.py | 3 +- 5 files changed, 154 insertions(+), 76 deletions(-) diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py index 61da5858a..aeb637b96 100644 --- a/colossalai/inference/engine/engine.py +++ b/colossalai/inference/engine/engine.py @@ -1,8 +1,9 @@ -from typing import Union +from typing import Optional, Union import torch import torch.distributed as dist import torch.nn as nn +from transformers.generation import GenerationConfig from transformers.utils import logging from colossalai.cluster import ProcessGroupMesh @@ -11,7 +12,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy -from ..kv_cache import MemoryManager +from ..kv_cache import BatchInferState, MemoryManager from .microbatch_manager import MicroBatchManager from .policies import model_policy_map @@ -31,10 +32,10 @@ class InferenceEngine: InferenceEngine is a class that handles the pipeline parallel inference. Args: + model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. tp_size (int): the size of tensor parallelism. pp_size (int): the size of pipeline parallelism. dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'. - model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided. micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. @@ -48,10 +49,10 @@ class InferenceEngine: def __init__( self, + model: nn.Module, tp_size: int = 1, pp_size: int = 1, dtype: str = "fp16", - model: nn.Module = None, model_policy: Policy = None, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, @@ -65,6 +66,14 @@ class InferenceEngine: do_sample: bool = False, num_beams: int = 1, ) -> None: + # sanity check + assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported." + assert ( + tp_size * pp_size == dist.get_world_size() + ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})" + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'" + if quant == "gptq": from ..quant.gptq import GPTQManager @@ -73,19 +82,12 @@ class InferenceEngine: elif quant == "smoothquant": model = model.model - assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported." - assert ( - tp_size * pp_size == dist.get_world_size() - ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})" - assert model, "Model should be provided." - assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - - assert max_batch_size <= 64, "Max batch size exceeds the constraint" - assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint" - assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'" self.pp_size = pp_size self.tp_size = tp_size self.quant = quant + self.max_input_len = max_input_len + self.max_batch_size = max_batch_size + self.max_output_len = max_output_len logger = logging.get_logger(__name__) if quant == "smoothquant" and dtype != "fp32": @@ -104,32 +106,34 @@ class InferenceEngine: if model_policy is None: model_policy = model_policy_map[model.config.model_type]() - # Init pg mesh - pg_mesh = ProcessGroupMesh(pp_size, tp_size) - - stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False) self.cache_manager_list = [ self._init_manager(model, max_batch_size, max_input_len, max_output_len) for _ in range(micro_batch_buffer_size or pp_size) ] - self.mb_manager = MicroBatchManager( - stage_manager.stage, - micro_batch_size, - micro_batch_buffer_size or pp_size, - max_input_len, - max_output_len, - self.cache_manager_list, - ) - self.verbose = verbose - self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) - self.model = self._shardformer( - model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None - ) + # Init pg mesh + self.pg_mesh = ProcessGroupMesh(pp_size, tp_size) + stage_manager = None + if pp_size > 1: + stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS, True) + mb_manager = MicroBatchManager( + stage_manager.stage, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, + ) + self.schedule = GenerateSchedule(stage_manager, mb_manager, verbose) + + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if tp_size > 1 else None + + self.model = self._shardformer(model, model_policy, stage_manager, self.tp_group) if quant == "gptq": self.gptq_manager.post_init_gptq_buffer(self.model) + self.verbose = verbose - def generate(self, input_list: Union[list, dict]): + def generate(self, input_list: Union[list, dict], generation_config: Optional[GenerationConfig] = None): """ Args: input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. @@ -139,13 +143,38 @@ class InferenceEngine: timestamp (float): the time cost of the inference, only return when verbose is `True`. """ - out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) - if self.verbose: - return out, timestamp + if self.pp_size > 1: + out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) + if self.verbose: + return out, timestamp + else: + return out else: + # when pipeline parallelism is not used, we can directly use the model to generate + # now the size if cache manager list is 1 + batch_infer_state = BatchInferState.init_from_batch( + input_list, self.max_input_len, self.max_output_len, self.cache_manager_list[0] + ) + # bind the infer state to the model (not lm model) + self.model.model.infer_state = batch_infer_state + if generation_config is not None: + generation_config.max_new_tokens = self.max_output_len + else: + generation_config = GenerationConfig( + max_new_tokens=self.max_output_len, pad_token_id=self.model.config.pad_token_id + ) + out = self.model.generate(**input_list, generation_config=generation_config) + # free the cache + self.cache_manager_list[0].free_all() return out - def _shardformer(self, model, model_policy, stage_manager, tp_group): + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + stage_manager: Optional[PipelineStageManager], + tp_group: Optional[dist.ProcessGroup], + ) -> nn.Module: shardconfig = ShardConfig( tensor_parallel_process_group=tp_group, pipeline_stage_manager=stage_manager, @@ -161,7 +190,7 @@ class InferenceEngine: shard_model, _ = shardformer.optimize(model, model_policy) return shard_model.cuda() - def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager: max_total_token_num = max_batch_size * (max_input_len + max_output_len) if model.config.model_type == "llama": head_dim = model.config.hidden_size // model.config.num_attention_heads @@ -188,8 +217,5 @@ class InferenceEngine: else: raise NotImplementedError("Only support llama, bloom and chatglm model.") - if self.quant == "smoothquant": - cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) - else: - cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num) - return cache_manager + dtype = torch.int8 if self.quant == "smoothquant" else self.dtype + return MemoryManager(max_total_token_num, dtype, head_num, head_dim, layer_num) diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py index b7bc94d0e..1390f1ed0 100644 --- a/colossalai/inference/engine/modeling/llama.py +++ b/colossalai/inference/engine/modeling/llama.py @@ -3,6 +3,7 @@ import math from typing import List, Optional, Tuple import torch +from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel from transformers.utils import logging @@ -29,13 +30,17 @@ except: try: from colossalai.kernel.triton.flash_decoding import token_flash_decoding + HAS_TRITON_FLASH_DECODING_KERNEL = True except: - print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") + print( + "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8" + ) HAS_TRITON_FLASH_DECODING_KERNEL = False - + try: from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True except: HAS_FLASH_KERNEL = False @@ -48,6 +53,7 @@ def rotate_half(x): x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] @@ -96,17 +102,22 @@ def llama_triton_context_attention( infer_state.max_len_in_batch, ) -def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1): + +def llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1 +): if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1: - token_flash_decoding(q = query_states, - o_tensor = attn_output, - infer_state = infer_state, - q_head_num = q_head_num, - head_dim = head_dim, - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]) - return - + token_flash_decoding( + q=query_states, + o_tensor=attn_output, + infer_state=infer_state, + q_head_num=q_head_num, + head_dim=head_dim, + cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + ) + return + if num_key_value_groups == 1: token_attention_fwd( query_states, @@ -157,6 +168,7 @@ class LlamaInferenceForwards: stage_index: Optional[List[int]] = None, ): r""" + This function is only used when pipeline is enabled. Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -217,6 +229,8 @@ class LlamaInferenceForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ): + """This function is always used.""" + infer_state = infer_state or getattr(self, "infer_state", None) return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -307,10 +321,14 @@ class LlamaInferenceForwards: # decoder layers infer_state.decode_layer_id = 0 + if stage_index is None: + stage_index = (0, len(self.layers)) start_idx, end_idx = stage_index[0], stage_index[1] if past_key_values is None: past_key_values = tuple([None] * (end_idx - start_idx + 1)) + # for HF api compatibility, kv-cache must be returned + next_decoder_cache = () if use_cache else None for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): decoder_layer = self.layers[idx] # NOTE: modify here for passing args to decoder layer @@ -325,8 +343,10 @@ class LlamaInferenceForwards: ) infer_state.decode_layer_id += 1 hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if stage_manager.is_last_stage() or stage_manager.num_stages == 1: + if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1: hidden_states = self.norm(hidden_states) # update indices @@ -335,6 +355,12 @@ class LlamaInferenceForwards: infer_state.seq_len += 1 infer_state.max_len_in_batch += 1 + next_cache = next_decoder_cache if use_cache else None + if stage_manager is None: + if not return_dict: + return (hidden_states, next_cache) + return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache) + return {"hidden_states": hidden_states} @staticmethod @@ -459,14 +485,15 @@ class LlamaInferenceForwards: ) if HAS_LIGHTLLM_KERNEL: - attn_output = torch.empty_like(query_states) - llama_triton_token_attention(query_states = query_states, - attn_output = attn_output, - infer_state = infer_state, - num_key_value_groups = self.num_key_value_groups, - q_head_num = q_len * self.num_heads, - head_dim = self.head_dim) + llama_triton_token_attention( + query_states=query_states, + attn_output=attn_output, + infer_state=infer_state, + num_key_value_groups=self.num_key_value_groups, + q_head_num=q_len * self.num_heads, + head_dim=self.head_dim, + ) else: self.num_heads // self.num_key_value_heads cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index dda46a756..c11c4e5d0 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -55,8 +55,7 @@ class MemoryManager: def alloc(self, required_size): """allocate space of required_size by providing indexes representing available physical spaces""" if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") - return None + raise RuntimeError(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) select_index = self.indexes[select_index] diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 9a26098b3..30ea1eca8 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -1,5 +1,6 @@ import argparse import time +from contextlib import nullcontext import torch import torch.distributed as dist @@ -106,15 +107,16 @@ def print_details_info(outputs, model_config, args, whole_end2end): def benchmark_inference(args): config = CONFIG_MAP[args.model] + config.pad_token_id = config.eos_token_id model = transformers.LlamaForCausalLM(config) if dist.get_rank() == 0: print("Model loaded") engine = InferenceEngine( + model, pp_size=args.pp_size, tp_size=args.tp_size, dtype=args.dtype, micro_batch_size=args.mb_size, - model=model, verbose=args.verbose, max_batch_size=args.batch_size, max_input_len=args.seq_len, @@ -124,14 +126,37 @@ def benchmark_inference(args): N_WARMUP_STEPS = 2 - for _ in range(N_WARMUP_STEPS): - engine.generate(data) + ctx = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log"), + ) + if args.profile + else nullcontext() + ) - torch.cuda.synchronize() - whole_end2end = time.time() - outputs = engine.generate(data) - torch.cuda.synchronize() - whole_end2end = time.time() - whole_end2end + with ctx: + for _ in range(N_WARMUP_STEPS): + engine.generate(data) + if args.profile: + ctx.step() + + if args.nsys: + torch.cuda.cudart().cudaProfilerStart() + whole_end2end = time.perf_counter() + outputs = engine.generate(data) + whole_end2end = time.perf_counter() - whole_end2end + if args.nsys: + torch.cuda.cudart().cudaProfilerStop() + if args.profile: + ctx.step() print_details_info(outputs, model.config, args, whole_end2end) @@ -157,12 +182,14 @@ if __name__ == "__main__": choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"], ) parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") - parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length") parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") parser.add_argument("--pp_size", type=int, default=1, help="pipeline size") parser.add_argument("--tp_size", type=int, default=1, help="pipeline size") parser.add_argument("--output_len", type=int, default=128, help="Output length") - parser.add_argument("--dtype", type=str, default="fp16", help="data type") + parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"]) parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler") + parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler") args = parser.parse_args() benchmark(args) diff --git a/examples/inference/example.py b/examples/inference/example.py index ff58664a3..2e541ff1f 100644 --- a/examples/inference/example.py +++ b/examples/inference/example.py @@ -30,9 +30,9 @@ def run_inference(args): model = LlamaForCausalLM.from_pretrained(model_name_or_path, pad_token_id=tokenizer.pad_token_id) engine = InferenceEngine( + model, tp_size=tp_size, pp_size=pp_size, - model=model, max_input_len=max_input_len, max_output_len=max_output_len, max_batch_size=max_batch_size, @@ -61,7 +61,6 @@ if __name__ == "__main__": parser.add_argument( "-m", "--model_name_or_path", type=str, help="Model name from huggingface or local path", default=None ) - parser.add_argument("-i", "--input", default="What is the longest river in the world?") parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None) parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")