mirror of https://github.com/hpcaitech/ColossalAI
[inference] decouple pp logic for llama (#5092)
* [example] update inference benchmark
* [inference] decouple pp logic for llama
* [inference] update examples

pull/5076/head^2
parent 79c4bff452
commit 27e62ba0f7
@@ -1,8 +1,9 @@
-from typing import Union
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from transformers.generation import GenerationConfig
 from transformers.utils import logging
 
 from colossalai.cluster import ProcessGroupMesh
@@ -11,7 +12,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
 
-from ..kv_cache import MemoryManager
+from ..kv_cache import BatchInferState, MemoryManager
 from .microbatch_manager import MicroBatchManager
 from .policies import model_policy_map
 
@@ -31,10 +32,10 @@ class InferenceEngine:
     InferenceEngine is a class that handles the pipeline parallel inference.
 
     Args:
-        model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
         tp_size (int): the size of tensor parallelism.
         pp_size (int): the size of pipeline parallelism.
         dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
+        model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
         model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided.
         micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
@@ -48,10 +49,10 @@
 
     def __init__(
         self,
+        model: nn.Module,
         tp_size: int = 1,
         pp_size: int = 1,
         dtype: str = "fp16",
-        model: nn.Module = None,
         model_policy: Policy = None,
         micro_batch_size: int = 1,
         micro_batch_buffer_size: int = None,
@@ -65,6 +66,14 @@
         do_sample: bool = False,
         num_beams: int = 1,
     ) -> None:
+        # sanity check
+        assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
+        assert (
+            tp_size * pp_size == dist.get_world_size()
+        ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+        assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+        assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+
         if quant == "gptq":
             from ..quant.gptq import GPTQManager
 
@@ -73,19 +82,12 @@
         elif quant == "smoothquant":
             model = model.model
 
-        assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
-        assert (
-            tp_size * pp_size == dist.get_world_size()
-        ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
-        assert model, "Model should be provided."
-        assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
-
-        assert max_batch_size <= 64, "Max batch size exceeds the constraint"
-        assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
-        assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
         self.pp_size = pp_size
         self.tp_size = tp_size
         self.quant = quant
+        self.max_input_len = max_input_len
+        self.max_batch_size = max_batch_size
+        self.max_output_len = max_output_len
 
         logger = logging.get_logger(__name__)
         if quant == "smoothquant" and dtype != "fp32":
@@ -104,32 +106,34 @@
         if model_policy is None:
             model_policy = model_policy_map[model.config.model_type]()
 
-        # Init pg mesh
-        pg_mesh = ProcessGroupMesh(pp_size, tp_size)
-
-        stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False)
         self.cache_manager_list = [
             self._init_manager(model, max_batch_size, max_input_len, max_output_len)
             for _ in range(micro_batch_buffer_size or pp_size)
         ]
-        self.mb_manager = MicroBatchManager(
-            stage_manager.stage,
-            micro_batch_size,
-            micro_batch_buffer_size or pp_size,
-            max_input_len,
-            max_output_len,
-            self.cache_manager_list,
-        )
-        self.verbose = verbose
-        self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
 
-        self.model = self._shardformer(
-            model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None
-        )
+        # Init pg mesh
+        self.pg_mesh = ProcessGroupMesh(pp_size, tp_size)
+        stage_manager = None
+        if pp_size > 1:
+            stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS, True)
+            mb_manager = MicroBatchManager(
+                stage_manager.stage,
+                micro_batch_size,
+                micro_batch_buffer_size or pp_size,
+                max_input_len,
+                max_output_len,
+                self.cache_manager_list,
+            )
+            self.schedule = GenerateSchedule(stage_manager, mb_manager, verbose)
+
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if tp_size > 1 else None
+
+        self.model = self._shardformer(model, model_policy, stage_manager, self.tp_group)
         if quant == "gptq":
             self.gptq_manager.post_init_gptq_buffer(self.model)
+        self.verbose = verbose
 
-    def generate(self, input_list: Union[list, dict]):
+    def generate(self, input_list: Union[list, dict], generation_config: Optional[GenerationConfig] = None):
         """
         Args:
             input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
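A minimal usage sketch of the reshaped constructor, in the spirit of the change above: the model is now the first positional argument, and the pipeline machinery (`PipelineStageManager`, `MicroBatchManager`, `GenerateSchedule`) is only built when `pp_size > 1`. The import path, the toy config, and the sizes below are assumptions, and the snippet expects a distributed environment already initialized by a launcher so that `tp_size * pp_size` equals the world size.

```python
# Sketch only: toy Llama config as a placeholder; run under torchrun/colossalai launch.
import transformers

from colossalai.inference import InferenceEngine  # assumed import path for this version

config = transformers.LlamaConfig(
    num_hidden_layers=2, hidden_size=128, intermediate_size=256, num_attention_heads=4
)
config.pad_token_id = config.eos_token_id  # mirrors the benchmark script further down
model = transformers.LlamaForCausalLM(config)

engine = InferenceEngine(
    model,          # model is now the first positional argument
    tp_size=1,
    pp_size=1,      # pp_size == 1 skips the pipeline stage/micro-batch setup entirely
    dtype="fp16",
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
)
```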
@@ -139,13 +143,38 @@
             timestamp (float): the time cost of the inference, only return when verbose is `True`.
         """
 
-        out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
-        if self.verbose:
-            return out, timestamp
+        if self.pp_size > 1:
+            out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
+            if self.verbose:
+                return out, timestamp
+            else:
+                return out
         else:
-            return out
+            # when pipeline parallelism is not used, we can directly use the model to generate
+            # now the size if cache manager list is 1
+            batch_infer_state = BatchInferState.init_from_batch(
+                input_list, self.max_input_len, self.max_output_len, self.cache_manager_list[0]
+            )
+            # bind the infer state to the model (not lm model)
+            self.model.model.infer_state = batch_infer_state
+            if generation_config is not None:
+                generation_config.max_new_tokens = self.max_output_len
+            else:
+                generation_config = GenerationConfig(
+                    max_new_tokens=self.max_output_len, pad_token_id=self.model.config.pad_token_id
+                )
+            out = self.model.generate(**input_list, generation_config=generation_config)
+            # free the cache
+            self.cache_manager_list[0].free_all()
+            return out
 
-    def _shardformer(self, model, model_policy, stage_manager, tp_group):
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: Optional[PipelineStageManager],
+        tp_group: Optional[dist.ProcessGroup],
+    ) -> nn.Module:
         shardconfig = ShardConfig(
             tensor_parallel_process_group=tp_group,
             pipeline_stage_manager=stage_manager,
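A sketch of calling the updated `generate()`, assuming an `engine` built on the single-GPU (`pp_size == 1`) path as in the previous snippet; the tokenizer checkpoint is a placeholder and the optional `GenerationConfig` comes from `transformers` (its `max_new_tokens` is overwritten with the engine's `max_output_len`, as the diff shows).

```python
# Sketch only: inputs follow the documented "BatchEncoding or dict" contract of generate().
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder
inputs = tokenizer("What is the longest river in the world?", return_tensors="pt").to("cuda")

gen_cfg = GenerationConfig(do_sample=False, pad_token_id=tokenizer.eos_token_id)

out = engine.generate(inputs, generation_config=gen_cfg)  # token ids on the non-pipeline path
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```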
@@ -161,7 +190,7 @@
         shard_model, _ = shardformer.optimize(model, model_policy)
         return shard_model.cuda()
 
-    def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
+    def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
         if model.config.model_type == "llama":
             head_dim = model.config.hidden_size // model.config.num_attention_heads
@@ -188,8 +217,5 @@
         else:
             raise NotImplementedError("Only support llama, bloom and chatglm model.")
 
-        if self.quant == "smoothquant":
-            cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
-        else:
-            cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
-        return cache_manager
+        dtype = torch.int8 if self.quant == "smoothquant" else self.dtype
+        return MemoryManager(max_total_token_num, dtype, head_num, head_dim, layer_num)

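As a rough sizing aid for `_init_manager` (which now returns the `MemoryManager` directly): the cache holds `max_batch_size * (max_input_len + max_output_len)` token slots. The sketch below assumes one key and one value buffer of shape `[max_total_token_num, head_num, head_dim]` per layer, which is an inference from how `key_buffer`/`value_buffer` are indexed per layer in the llama modeling code further down, not a documented layout.

```python
# Back-of-the-envelope KV-cache sizing under the assumed per-layer buffer layout.
def kv_cache_bytes(max_batch_size, max_input_len, max_output_len,
                   head_num, head_dim, layer_num, dtype_bytes=2):
    max_total_token_num = max_batch_size * (max_input_len + max_output_len)
    per_layer = 2 * max_total_token_num * head_num * head_dim * dtype_bytes  # key + value
    return max_total_token_num, per_layer * layer_num

# Llama-7B-like shapes (32 heads x 128 head_dim, 32 layers), fp16:
tokens, total = kv_cache_bytes(8, 64, 128, head_num=32, head_dim=128, layer_num=32)
print(tokens, total / 1024**3)  # 1536 token slots, roughly 0.75 GiB
```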
@@ -3,6 +3,7 @@ import math
 from typing import List, Optional, Tuple
 
 import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
 from transformers.utils import logging
 
@@ -29,13 +30,17 @@ except:
 
 try:
     from colossalai.kernel.triton.flash_decoding import token_flash_decoding
 
     HAS_TRITON_FLASH_DECODING_KERNEL = True
 except:
-    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    print(
+        "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
+    )
     HAS_TRITON_FLASH_DECODING_KERNEL = False
 
 
 try:
     from flash_attn import flash_attn_with_kvcache
 
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
@@ -48,6 +53,7 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -96,17 +102,22 @@ def llama_triton_context_attention(
         infer_state.max_len_in_batch,
     )
 
-def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
+def llama_triton_token_attention(
+    query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
+):
     if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
-        token_flash_decoding(q = query_states,
-                             o_tensor = attn_output,
-                             infer_state = infer_state,
-                             q_head_num = q_head_num,
-                             head_dim = head_dim,
-                             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
-                             cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
+        token_flash_decoding(
+            q=query_states,
+            o_tensor=attn_output,
+            infer_state=infer_state,
+            q_head_num=q_head_num,
+            head_dim=head_dim,
+            cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+        )
         return
 
     if num_key_value_groups == 1:
         token_attention_fwd(
             query_states,
@@ -157,6 +168,7 @@ class LlamaInferenceForwards:
         stage_index: Optional[List[int]] = None,
     ):
         r"""
+        This function is only used when pipeline is enabled.
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -217,6 +229,8 @@
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
     ):
+        """This function is always used."""
+        infer_state = infer_state or getattr(self, "infer_state", None)
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -307,10 +321,14 @@
         # decoder layers
         infer_state.decode_layer_id = 0
 
+        if stage_index is None:
+            stage_index = (0, len(self.layers))
         start_idx, end_idx = stage_index[0], stage_index[1]
+        if past_key_values is None:
+            past_key_values = tuple([None] * (end_idx - start_idx + 1))
 
         # for HF api compatibility, kv-cache must be returned
         next_decoder_cache = () if use_cache else None
         for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
             decoder_layer = self.layers[idx]
             # NOTE: modify here for passing args to decoder layer
@@ -325,8 +343,10 @@
             )
             infer_state.decode_layer_id += 1
             hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 
-        if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+        if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
             hidden_states = self.norm(hidden_states)
 
         # update indices
@@ -335,6 +355,12 @@
             infer_state.seq_len += 1
             infer_state.max_len_in_batch += 1
 
+        next_cache = next_decoder_cache if use_cache else None
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         return {"hidden_states": hidden_states}
 
     @staticmethod
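To see the decoupling in isolation, here is a standalone mirror of the new return logic above; `finalize_llama_output` is a hypothetical helper for illustration, not a ColossalAI API. When no `PipelineStageManager` is present the forward ends like a regular Hugging Face forward, otherwise it keeps handing hidden states to the pipeline schedule as a dict.

```python
# Illustrative mirror of the stage_manager-aware return path (hypothetical helper).
from typing import Optional

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast


def finalize_llama_output(hidden_states: torch.Tensor, next_cache, stage_manager: Optional[object], return_dict: bool):
    if stage_manager is None:
        # tp-only or single-GPU run: behave like a plain HF model forward
        if not return_dict:
            return (hidden_states, next_cache)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
    # pipeline run: pass hidden states on to the next stage through the schedule
    return {"hidden_states": hidden_states}


out = finalize_llama_output(torch.zeros(1, 4, 8), None, stage_manager=None, return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 4, 8])
```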
@@ -459,14 +485,15 @@
         )
 
         if HAS_LIGHTLLM_KERNEL:
-
             attn_output = torch.empty_like(query_states)
-            llama_triton_token_attention(query_states = query_states,
-                                         attn_output = attn_output,
-                                         infer_state = infer_state,
-                                         num_key_value_groups = self.num_key_value_groups,
-                                         q_head_num = q_len * self.num_heads,
-                                         head_dim = self.head_dim)
+            llama_triton_token_attention(
+                query_states=query_states,
+                attn_output=attn_output,
+                infer_state=infer_state,
+                num_key_value_groups=self.num_key_value_groups,
+                q_head_num=q_len * self.num_heads,
+                head_dim=self.head_dim,
+            )
         else:
             self.num_heads // self.num_key_value_heads
             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]

@@ -55,8 +55,7 @@ class MemoryManager:
     def alloc(self, required_size):
         """allocate space of required_size by providing indexes representing available physical spaces"""
         if required_size > self.available_size:
-            self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
-            return None
+            raise RuntimeError(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
         torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
         select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
         select_index = self.indexes[select_index]

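With this change, `alloc` raises instead of warning and returning `None`, so cache exhaustion has to be handled as an error by the caller. A minimal sketch, assuming the import path below matches the installed package layout (the engine above imports it relatively as `..kv_cache`), that the constructor takes the same positional arguments the engine passes, and that a CUDA device is available for the buffers.

```python
# Sketch only: import path and positional arguments are assumptions based on the engine code above.
import torch

from colossalai.inference.kv_cache import MemoryManager  # assumed import path

manager = MemoryManager(1536, torch.float16, 32, 128, 32)  # (max_total_token_num, dtype, head_num, head_dim, layer_num)

try:
    slots = manager.alloc(256)  # indexes of 256 free token slots
except RuntimeError:
    # alloc() used to log a warning and return None; now exhaustion must be handled
    # explicitly, e.g. by freeing finished sequences before retrying.
    manager.free_all()
    slots = manager.alloc(256)
```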
@@ -1,5 +1,6 @@
 import argparse
 import time
+from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
@@ -106,15 +107,16 @@ def print_details_info(outputs, model_config, args, whole_end2end):
 
 def benchmark_inference(args):
     config = CONFIG_MAP[args.model]
+    config.pad_token_id = config.eos_token_id
     model = transformers.LlamaForCausalLM(config)
     if dist.get_rank() == 0:
         print("Model loaded")
     engine = InferenceEngine(
+        model,
         pp_size=args.pp_size,
         tp_size=args.tp_size,
         dtype=args.dtype,
         micro_batch_size=args.mb_size,
-        model=model,
         verbose=args.verbose,
         max_batch_size=args.batch_size,
         max_input_len=args.seq_len,
@@ -124,14 +126,37 @@ def benchmark_inference(args):
 
     N_WARMUP_STEPS = 2
 
-    for _ in range(N_WARMUP_STEPS):
-        engine.generate(data)
+    ctx = (
+        torch.profiler.profile(
+            record_shapes=True,
+            with_stack=True,
+            with_modules=True,
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log"),
+        )
+        if args.profile
+        else nullcontext()
+    )
 
-    torch.cuda.synchronize()
-    whole_end2end = time.time()
-    outputs = engine.generate(data)
-    torch.cuda.synchronize()
-    whole_end2end = time.time() - whole_end2end
+    with ctx:
+        for _ in range(N_WARMUP_STEPS):
+            engine.generate(data)
+            if args.profile:
+                ctx.step()
+
+        if args.nsys:
+            torch.cuda.cudart().cudaProfilerStart()
+        whole_end2end = time.perf_counter()
+        outputs = engine.generate(data)
+        whole_end2end = time.perf_counter() - whole_end2end
+        if args.nsys:
+            torch.cuda.cudart().cudaProfilerStop()
+        if args.profile:
+            ctx.step()
 
     print_details_info(outputs, model.config, args, whole_end2end)
 
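The benchmark's new profiling setup reduces to the pattern below; this standalone sketch swaps `engine.generate` for a dummy matmul so it runs on its own, while keeping the same schedule and trace handler as above. The CUDA activity is only added when a GPU is present.

```python
# Standalone sketch of the optional torch.profiler / nullcontext pattern used by the benchmark.
from contextlib import nullcontext

import torch

N_WARMUP_STEPS = 2
use_profiler = True  # stands in for args.profile

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

prof = (
    torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log"),
        record_shapes=True,
        with_stack=True,
        with_modules=True,
    )
    if use_profiler
    else nullcontext()
)


def dummy_step():
    x = torch.randn(1024, 1024)
    return x @ x


with prof:
    for _ in range(N_WARMUP_STEPS):
        dummy_step()
        if use_profiler:
            prof.step()  # advance the schedule through its warmup window
    dummy_step()         # this step lands in the schedule's "active" window
    if use_profiler:
        prof.step()
```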
@@ -157,12 +182,14 @@ if __name__ == "__main__":
         choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
     )
     parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
-    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
+    parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
     parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
     parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
     parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
     parser.add_argument("--output_len", type=int, default=128, help="Output length")
-    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
+    parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
     parser.add_argument("-v", "--verbose", default=False, action="store_true")
+    parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
+    parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
     args = parser.parse_args()
     benchmark(args)

@@ -30,9 +30,9 @@ def run_inference(args):
     model = LlamaForCausalLM.from_pretrained(model_name_or_path, pad_token_id=tokenizer.pad_token_id)
 
     engine = InferenceEngine(
+        model,
         tp_size=tp_size,
         pp_size=pp_size,
-        model=model,
         max_input_len=max_input_len,
         max_output_len=max_output_len,
         max_batch_size=max_batch_size,
@@ -61,7 +61,6 @@
     parser.add_argument(
         "-m", "--model_name_or_path", type=str, help="Model name from huggingface or local path", default=None
     )
-    parser.add_argument("-i", "--input", default="What is the longest river in the world?")
     parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
     parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
     parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")