
[inference] decouple pp logic for llama (#5092)

* [example] update inference benchmark

* [inference] decouple pp logic for llama

* [inference] update examples
Hongxin Liu 1 year ago committed by GitHub
commit 27e62ba0f7
  1. colossalai/inference/engine/engine.py (110 lines changed)
  2. colossalai/inference/engine/modeling/llama.py (67 lines changed)
  3. colossalai/inference/kv_cache/kvcache_manager.py (3 lines changed)
  4. examples/inference/benchmark_llama.py (47 lines changed)
  5. examples/inference/example.py (3 lines changed)

colossalai/inference/engine/engine.py (110 lines changed)

@@ -1,8 +1,9 @@
from typing import Union
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.generation import GenerationConfig
from transformers.utils import logging
from colossalai.cluster import ProcessGroupMesh
@@ -11,7 +12,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from ..kv_cache import MemoryManager
from ..kv_cache import BatchInferState, MemoryManager
from .microbatch_manager import MicroBatchManager
from .policies import model_policy_map
@@ -31,10 +32,10 @@ class InferenceEngine:
InferenceEngine is a class that handles the pipeline parallel inference.
Args:
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
tp_size (int): the size of tensor parallelism.
pp_size (int): the size of pipeline parallelism.
dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by ShardFormer to shard the model. It will be determined by the model type if not provided.
micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
@@ -48,10 +49,10 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
tp_size: int = 1,
pp_size: int = 1,
dtype: str = "fp16",
model: nn.Module = None,
model_policy: Policy = None,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
@@ -65,6 +66,14 @@ class InferenceEngine:
do_sample: bool = False,
num_beams: int = 1,
) -> None:
# sanity check
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
if quant == "gptq":
from ..quant.gptq import GPTQManager
@@ -73,19 +82,12 @@ class InferenceEngine:
elif quant == "smoothquant":
model = model.model
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert model, "Model should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
self.pp_size = pp_size
self.tp_size = tp_size
self.quant = quant
self.max_input_len = max_input_len
self.max_batch_size = max_batch_size
self.max_output_len = max_output_len
logger = logging.get_logger(__name__)
if quant == "smoothquant" and dtype != "fp32":
@@ -104,32 +106,34 @@ class InferenceEngine:
if model_policy is None:
model_policy = model_policy_map[model.config.model_type]()
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
self.model = self._shardformer(
model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None
)
# Init pg mesh
self.pg_mesh = ProcessGroupMesh(pp_size, tp_size)
stage_manager = None
if pp_size > 1:
stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS, True)
mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.schedule = GenerateSchedule(stage_manager, mb_manager, verbose)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if tp_size > 1 else None
self.model = self._shardformer(model, model_policy, stage_manager, self.tp_group)
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)
self.verbose = verbose
def generate(self, input_list: Union[list, dict]):
def generate(self, input_list: Union[list, dict], generation_config: Optional[GenerationConfig] = None):
"""
Args:
input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
@@ -139,13 +143,38 @@ class InferenceEngine:
timestamp (float): the time cost of the inference, only returned when verbose is `True`.
"""
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
if self.pp_size > 1:
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
else:
return out
else:
# when pipeline parallelism is not used, we can directly use the model to generate
# now the size of the cache manager list is 1
batch_infer_state = BatchInferState.init_from_batch(
input_list, self.max_input_len, self.max_output_len, self.cache_manager_list[0]
)
# bind the infer state to the model (not lm model)
self.model.model.infer_state = batch_infer_state
if generation_config is not None:
generation_config.max_new_tokens = self.max_output_len
else:
generation_config = GenerationConfig(
max_new_tokens=self.max_output_len, pad_token_id=self.model.config.pad_token_id
)
out = self.model.generate(**input_list, generation_config=generation_config)
# free the cache
self.cache_manager_list[0].free_all()
return out
def _shardformer(self, model, model_policy, stage_manager, tp_group):
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
stage_manager: Optional[PipelineStageManager],
tp_group: Optional[dist.ProcessGroup],
) -> nn.Module:
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
@@ -161,7 +190,7 @@ class InferenceEngine:
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
if model.config.model_type == "llama":
head_dim = model.config.hidden_size // model.config.num_attention_heads
@@ -188,8 +217,5 @@ class InferenceEngine:
else:
raise NotImplementedError("Only llama, bloom and chatglm models are supported.")
if self.quant == "smoothquant":
cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
else:
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
dtype = torch.int8 if self.quant == "smoothquant" else self.dtype
return MemoryManager(max_total_token_num, dtype, head_num, head_dim, layer_num)
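With this change the constructor takes `tp_size`, `pp_size`, and `dtype` ahead of the model, `generate` accepts an optional `GenerationConfig`, and a `pp_size` of 1 bypasses the pipeline schedule entirely. Below is a minimal driver sketch of the updated API, assuming a locally available LLaMA checkpoint, a torchrun-style launch (the engine asserts `tp_size * pp_size == dist.get_world_size()`), and the import path taken from the file layout in this diff; the package may re-export the class at a shorter path, and the checkpoint path is a placeholder.

```python
# Sketch only: paths and sizes are placeholders, not taken from this commit.
import colossalai
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.engine.engine import InferenceEngine  # path per this diff

colossalai.launch_from_torch(config={})  # assumes a torchrun/colossalai run launch

tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama")  # placeholder path
model = LlamaForCausalLM.from_pretrained("/path/to/llama", pad_token_id=tokenizer.pad_token_id)

engine = InferenceEngine(
    tp_size=1,
    pp_size=1,            # pp_size == 1 now takes the direct, non-pipeline generate path
    dtype="fp16",
    model=model,
    max_batch_size=4,
    max_input_len=128,
    max_output_len=64,
)

inputs = tokenizer("What is the longest river in the world?", return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
# generation_config is optional; max_new_tokens is overwritten with max_output_len either way
outputs = engine.generate(inputs, generation_config=GenerationConfig(do_sample=False))
```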

colossalai/inference/engine/modeling/llama.py (67 lines changed)

@@ -3,6 +3,7 @@ import math
from typing import List, Optional, Tuple
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from transformers.utils import logging
@@ -29,13 +30,17 @@ except:
try:
from colossalai.kernel.triton.flash_decoding import token_flash_decoding
HAS_TRITON_FLASH_DECODING_KERNEL = True
except:
print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
print(
"no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
)
HAS_TRITON_FLASH_DECODING_KERNEL = False
try:
from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True
except:
HAS_FLASH_KERNEL = False
@@ -48,6 +53,7 @@ def rotate_half(x):
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -96,17 +102,22 @@ def llama_triton_context_attention(
infer_state.max_len_in_batch,
)
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
def llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
):
if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
token_flash_decoding(q = query_states,
o_tensor = attn_output,
infer_state = infer_state,
q_head_num = q_head_num,
head_dim = head_dim,
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
return
token_flash_decoding(
q=query_states,
o_tensor=attn_output,
infer_state=infer_state,
q_head_num=q_head_num,
head_dim=head_dim,
cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
)
return
if num_key_value_groups == 1:
token_attention_fwd(
query_states,
@@ -157,6 +168,7 @@ class LlamaInferenceForwards:
stage_index: Optional[List[int]] = None,
):
r"""
This function is only used when pipeline is enabled.
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -217,6 +229,8 @@ class LlamaInferenceForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
"""This function is always used."""
infer_state = infer_state or getattr(self, "infer_state", None)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -307,10 +321,14 @@ class LlamaInferenceForwards:
# decoder layers
infer_state.decode_layer_id = 0
if stage_index is None:
stage_index = (0, len(self.layers))
start_idx, end_idx = stage_index[0], stage_index[1]
if past_key_values is None:
past_key_values = tuple([None] * (end_idx - start_idx + 1))
# for HF api compatibility, kv-cache must be returned
next_decoder_cache = () if use_cache else None
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
decoder_layer = self.layers[idx]
# NOTE: modify here for passing args to decoder layer
@@ -325,8 +343,10 @@ class LlamaInferenceForwards:
)
infer_state.decode_layer_id += 1
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
hidden_states = self.norm(hidden_states)
# update indices
@@ -335,6 +355,12 @@ class LlamaInferenceForwards:
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
next_cache = next_decoder_cache if use_cache else None
if stage_manager is None:
if not return_dict:
return (hidden_states, next_cache)
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
return {"hidden_states": hidden_states}
@staticmethod
@@ -459,14 +485,15 @@ class LlamaInferenceForwards:
)
if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states)
llama_triton_token_attention(query_states = query_states,
attn_output = attn_output,
infer_state = infer_state,
num_key_value_groups = self.num_key_value_groups,
q_head_num = q_len * self.num_heads,
head_dim = self.head_dim)
llama_triton_token_attention(
query_states=query_states,
attn_output=attn_output,
infer_state=infer_state,
num_key_value_groups=self.num_key_value_groups,
q_head_num=q_len * self.num_heads,
head_dim=self.head_dim,
)
else:
self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
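The key decoupling point in this file is the tail of `llama_model_forward`: with a stage manager present it keeps handing a `hidden_states` dict to the next pipeline stage, while without one it returns an HF-style output so `model.generate()` can consume it directly. A condensed, self-contained sketch of that branch follows; the `norm` argument stands in for the model's `self.norm`, and the function name is illustrative, not part of the diff.

```python
from transformers.modeling_outputs import BaseModelOutputWithPast


def finalize_hidden_states(hidden_states, next_cache, norm, stage_manager=None, return_dict=True):
    """Condensed sketch of the return logic added in this diff: the final norm runs on the
    last (or only) stage, and the output type depends on whether a pipeline is present."""
    if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
        hidden_states = norm(hidden_states)
    if stage_manager is None:
        # non-pipeline path: return an HF-compatible output so generate() works out of the box
        if not return_dict:
            return (hidden_states, next_cache)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
    # pipeline path: pass hidden states on to the next stage via the schedule
    return {"hidden_states": hidden_states}
```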

colossalai/inference/kv_cache/kvcache_manager.py (vendored, 3 lines changed)

@@ -55,8 +55,7 @@ class MemoryManager:
def alloc(self, required_size):
"""allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
self.logger.warning(f"Not enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
raise RuntimeError(f"Not enough cache: required_size {required_size} " f"left_size {self.available_size}")
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
select_index = self.indexes[select_index]
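Since `alloc` now raises a `RuntimeError` on cache exhaustion instead of logging a warning and returning `None`, callers that used to check the return value need to catch the exception instead. A hedged caller-side sketch, where the wrapper name and recovery policy are illustrative and not part of the diff:

```python
def alloc_with_fallback(cache_manager, required_size):
    """Illustrative wrapper: MemoryManager.alloc now raises instead of returning None."""
    try:
        return cache_manager.alloc(required_size)
    except RuntimeError:
        # Not enough free KV-cache slots: the caller decides how to recover, e.g. wait for
        # running requests to finish and release slots via free()/free_all(), then retry.
        return None
```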

examples/inference/benchmark_llama.py (47 lines changed)

@@ -1,5 +1,6 @@
import argparse
import time
from contextlib import nullcontext
import torch
import torch.distributed as dist
@@ -106,15 +107,16 @@ def print_details_info(outputs, model_config, args, whole_end2end):
def benchmark_inference(args):
config = CONFIG_MAP[args.model]
config.pad_token_id = config.eos_token_id
model = transformers.LlamaForCausalLM(config)
if dist.get_rank() == 0:
print("Model loaded")
engine = InferenceEngine(
model,
pp_size=args.pp_size,
tp_size=args.tp_size,
dtype=args.dtype,
micro_batch_size=args.mb_size,
model=model,
verbose=args.verbose,
max_batch_size=args.batch_size,
max_input_len=args.seq_len,
@@ -124,14 +126,37 @@ def benchmark_inference(args):
N_WARMUP_STEPS = 2
for _ in range(N_WARMUP_STEPS):
engine.generate(data)
ctx = (
torch.profiler.profile(
record_shapes=True,
with_stack=True,
with_modules=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log"),
)
if args.profile
else nullcontext()
)
torch.cuda.synchronize()
whole_end2end = time.time()
outputs = engine.generate(data)
torch.cuda.synchronize()
whole_end2end = time.time() - whole_end2end
with ctx:
for _ in range(N_WARMUP_STEPS):
engine.generate(data)
if args.profile:
ctx.step()
if args.nsys:
torch.cuda.cudart().cudaProfilerStart()
whole_end2end = time.perf_counter()
outputs = engine.generate(data)
whole_end2end = time.perf_counter() - whole_end2end
if args.nsys:
torch.cuda.cudart().cudaProfilerStop()
if args.profile:
ctx.step()
print_details_info(outputs, model.config, args, whole_end2end)
@@ -157,12 +182,14 @@ if __name__ == "__main__":
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
)
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
parser.add_argument("--output_len", type=int, default=128, help="Output length")
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
parser.add_argument("-v", "--verbose", default=False, action="store_true")
parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
args = parser.parse_args()
benchmark(args)
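The benchmark now wraps warmup and the timed call in either a `torch.profiler.profile` context with an explicit schedule or a `nullcontext`, calling `ctx.step()` once per iteration so the profiler advances past warmup into its active window. Below is a standalone sketch of the same pattern, assuming a CUDA device and using a placeholder matmul in place of `engine.generate(data)`.

```python
from contextlib import nullcontext

import torch

N_WARMUP_STEPS = 2
enable_profile = True  # stand-in for args.profile

ctx = (
    torch.profiler.profile(
        record_shapes=True,
        with_stack=True,
        with_modules=True,
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log"),
    )
    if enable_profile
    else nullcontext()
)

x = torch.randn(2048, 2048, device="cuda")
with ctx:
    for _ in range(N_WARMUP_STEPS + 1):  # warmup steps plus one active step
        y = x @ x  # placeholder for engine.generate(data)
        torch.cuda.synchronize()
        if enable_profile:
            ctx.step()  # advance the wait/warmup/active schedule
```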

examples/inference/example.py (3 lines changed)

@@ -30,9 +30,9 @@ def run_inference(args):
model = LlamaForCausalLM.from_pretrained(model_name_or_path, pad_token_id=tokenizer.pad_token_id)
engine = InferenceEngine(
model,
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
@@ -61,7 +61,6 @@ if __name__ == "__main__":
parser.add_argument(
"-m", "--model_name_or_path", type=str, help="Model name from huggingface or local path", default=None
)
parser.add_argument("-i", "--input", default="What is the longest river in the world?")
parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
