[Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo
Bin Jia 2023-11-01 12:46:21 +08:00 committed by GitHub
parent 335cb105e2
commit b6696beb04
12 changed files with 268 additions and 203 deletions

View File

@ -1,4 +1,4 @@
from .pipeline import PPInferEngine
from .hybridengine import CaiInferEngine
from .hybridengine.polices import LlamaModelInferPolicy
__all__ = ['PPInferEngine']
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]

View File

@ -0,0 +1,3 @@
from .engine import CaiInferEngine
__all__ = ["CaiInferEngine"]

View File

@ -1,4 +1,5 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding
@ -8,23 +9,27 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from ..pipeline.microbatch_manager import MicroBatchManager
from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager
PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
]
class PPInferEngine:
class CaiInferEngine:
"""
PPInferEngine is a class that handles the pipeline parallel inference.
CaiInferEngine is a class that handles pipeline-parallel and tensor-parallel inference.
Args:
pp_size (int): the number of pipeline stages.
pp_model (`nn.Module`): the model already in pipeline parallelism style.
tp_size (int): the size of tensor parallelism.
pp_size (int): the size of pipeline parallelism.
model (`nn.Module`): the original model (not in pipeline style); it will be sharded by `ShardFormer`.
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should equal the number of pipeline stages.
new_length (int): the new length of the input sequence.
early_stopping (bool): whether to stop early.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.
@ -32,7 +37,7 @@ class PPInferEngine:
Example:
```python
from colossalai.inference import PPInferEngine
from colossalai.inference import CaiInferEngine
from colossalai.inference import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
@ -42,7 +47,7 @@ class PPInferEngine:
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_tokenizer")
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
@ -54,12 +59,11 @@ class PPInferEngine:
def __init__(
self,
pp_size: int,
tp_size: int = 1,
pp_size: int = 1,
dtype: str = "fp16",
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
@ -71,12 +75,21 @@ class PPInferEngine:
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert model and model_policy, "Both model and model_policy should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
max_output_len = max(max_output_len, max_input_len + new_length)
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
# TODO: support tensor-parallel-only inference
assert pp_size > 1, "Tensor-parallel-only inference is not supported yet."
self.pp_size = pp_size
self.tp_size = tp_size
if dtype == "fp16":
self.dtype = torch.float16
model.half()
@ -85,24 +98,29 @@ class PPInferEngine:
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.model = pp_model or self._shardformer(model, model_policy)
self.cache_manager_list = [
self._init_manager(max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
self.stage_manager.stage,
new_length,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
stage_manager = None
if pp_size > 1:
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
def inference(self, input_list):
"""
@ -124,10 +142,10 @@ class PPInferEngine:
else:
return out
def _shardformer(self, model, model_policy):
def _shardformer(self, model, model_policy, stage_manager, tp_group):
shardconfig = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=self.stage_manager,
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_all_optimization=False,
@ -139,14 +157,12 @@ class PPInferEngine:
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
head_num = self.model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads
num_hidden_layers = (
self.model.config.num_hidden_layers
if hasattr(self.model.config, "num_hidden_layers")
else self.model.config.num_layers
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
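Two details in this file are worth spelling out. First, `tp_size * pp_size` must equal the launched world size, and pipeline-only mode still requires `pp_size > 1`. Second, `_init_manager` builds one `MemoryManager` per micro-batch buffer slot, sized from the batch/length limits and the slice of layers owned by each pipeline stage. The sketch below redoes that sizing arithmetic with illustrative Llama-7B-like numbers (hidden_size=4096, 32 heads, 32 layers are assumptions for the example, not values taken from this diff):

```python
# Hedged sketch of the sizing done in CaiInferEngine._init_manager.
def kv_cache_shape(max_batch_size, max_input_len, max_output_len,
                   hidden_size, num_attention_heads, num_hidden_layers, pp_size):
    # One KV slot per token that can be alive in a full batch at once.
    max_total_token_num = max_batch_size * (max_input_len + max_output_len)
    head_dim = hidden_size // num_attention_heads
    head_num = num_attention_heads
    # Each pipeline stage only caches the decoder layers it owns.
    layer_num = num_hidden_layers // pp_size
    return max_total_token_num, head_num, head_dim, layer_num

# Llama-7B-like config (assumed), pp_size=2:
print(kv_cache_shape(max_batch_size=4, max_input_len=32, max_output_len=32,
                     hidden_size=4096, num_attention_heads=32,
                     num_hidden_layers=32, pp_size=2))
# -> (256, 32, 128, 16): 256 token slots, 32 heads of dim 128, 16 layers per stage
```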

View File

@ -1,37 +1,25 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
import math
from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from transformers.utils import logging
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import copy_kv_to_mem_cache
try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print(
"if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_context_attention_fwd,
)
HAS_VLLM_KERNERL = False
try:
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
HAS_LIGHTLLM_KERNEL = True
@ -39,6 +27,14 @@ except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
try:
from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True
except:
HAS_FLASH_KERNEL = False
print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@ -59,6 +55,75 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
return q_embed, k_embed
def llama_triton_context_attention(
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
):
if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
lightllm_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
if num_key_value_groups == 1:
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
Llama2TokenAttentionForwards.token_attn(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)
class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
@ -144,13 +209,9 @@ class LlamaInferenceForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
# batch_size = input_ids.shape[0] # input_ids.shape[0]
# print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}")
# infer_state = self.infer_state
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if stage_manager is None or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@ -172,12 +233,10 @@ class LlamaInferenceForwards:
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
if infer_state.is_context_stage is False:
past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length_with_past + past_key_values_length
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@ -197,26 +256,19 @@ class LlamaInferenceForwards:
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
new_shape = [1] * position_ids.dim()
new_shape[0] = batch_size
position_ids = position_ids.repeat(*new_shape).view(-1, seq_length)
position_ids = position_ids.repeat(batch_size, 1)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@ -227,15 +279,17 @@ class LlamaInferenceForwards:
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
)
attention_mask = self._prepare_decoder_attention_mask(
@ -243,10 +297,6 @@ class LlamaInferenceForwards:
)
# decoder layers
() if output_hidden_states else None
() if output_attentions else None
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
start_idx, end_idx = stage_index[0], stage_index[1]
@ -268,19 +318,15 @@ class LlamaInferenceForwards:
infer_state.decode_layer_id += 1
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if stage_manager.is_last_stage():
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
hidden_states = self.norm(hidden_states)
next_cache = next_decoder_cache if use_cache else None
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
# TODO: fix this to necessary return
# if not return_dict:
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@ -290,8 +336,7 @@ class LlamaInferenceForwards:
# hidden_states=all_hidden_states,
# attentions=all_self_attns,
# )
# print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}")
return {"hidden_states": hidden_states, "past_key_values": next_cache}
return {"hidden_states": hidden_states}
@staticmethod
def llama_decoder_layer_forward(
@ -307,7 +352,6 @@ class LlamaInferenceForwards:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@ -357,28 +401,24 @@ class LlamaInferenceForwards:
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
if infer_state.is_context_stage:
# print(f"rank:{torch.distributed.get_rank()}, {infer_state}")
# first token generation
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
@ -387,19 +427,16 @@ class LlamaInferenceForwards:
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_context_attn_fwd(
llama_triton_context_attention(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state,
num_key_value_groups=self.num_key_value_groups,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@ -422,45 +459,31 @@ class LlamaInferenceForwards:
infer_state.cache_manager,
)
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states)
llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
)
else:
self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
attn_output = flash_attn_with_kvcache(
q=query_states,
k_cache=copy_cache_k,
v_cache=copy_cache_v,
softmax_scale=1 / math.sqrt(self.head_dim),
causal=True,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
# print(f"rank:{torch.distributed.get_rank()}, {attn_output}")
attn_output = self.o_proj(attn_output)
# return past_key_value as None
return attn_output, None, None
def get_llama_vllm_rmsnorm_forward():
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
return _vllm_rmsnorm_forward
else:
return None
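The refactored forwards pick an attention backend per phase: prefill goes through the Triton/lightllm context-attention kernels (the lightllm llama2 variant whenever `num_key_value_groups > 1`, i.e. grouped-query attention), and decode uses lightllm token attention when it is installed, otherwise falling back to `flash_attn_with_kvcache`; the vLLM rmsnorm path is removed entirely. The helper below is only a sketch of that dispatch order, not ColossalAI code:

```python
# Hypothetical selector mirroring the kernel dispatch in this file.
def pick_attention_backend(is_context_stage: bool,
                           num_key_value_groups: int,
                           has_lightllm: bool,
                           has_flash: bool) -> str:
    if is_context_stage:  # prefill over the whole prompt
        if num_key_value_groups > 1:  # GQA models (e.g. llama2-70b)
            assert has_lightllm, "lightllm kernels are required for GQA prefill"
            return "lightllm_llama2_context_attention_fwd"
        return ("lightllm_context_attention_fwd" if has_lightllm
                else "llama_context_attn_fwd")  # colossalai triton kernel
    # decode: one new token per step, reading from the KV cache
    if has_lightllm:
        return "llama_triton_token_attention"
    assert has_flash, "either lightllm or flash-attn is needed for decoding"
    return "flash_attn_with_kvcache"

print(pick_attention_backend(False, 1, has_lightllm=False, has_flash=True))
# -> flash_attn_with_kvcache
```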

View File

@ -17,7 +17,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
from ..modeling.llama import LlamaInferenceForwards
try:
from colossalai.kernel.triton import rmsnorm_forward
@ -120,9 +120,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
else:
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}
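The policy now binds the Triton RMSNorm forward only when the kernel imports cleanly; the vLLM fallback (`get_llama_vllm_rmsnorm_forward`) is gone because of the precision issue noted in the deleted comment. The toy below sketches the forward-replacement idea itself and is not ShardFormer's actual `method_replacement` machinery:

```python
import torch
import torch.nn as nn

class ToyRMSNorm(nn.Module):
    """Stand-in for LlamaRMSNorm; weight and variance_epsilon mirror the real module."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        raise NotImplementedError("replaced below, as the policy does")

def rmsnorm_reference_forward(self, hidden_states):
    # Plain-PyTorch RMSNorm math; the real policy binds a Triton kernel instead.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)

norm = ToyRMSNorm(8)
# In the policy this only happens under `if HAS_TRITON_RMSNORM:`.
norm.forward = rmsnorm_reference_forward.__get__(norm, ToyRMSNorm)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```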

View File

@ -1,3 +1,3 @@
from .engine import PPInferEngine
from .microbatch_manager import MicroBatchManager
__all__ = ["PPInferEngine"]
__all__ = ["MicroBatchManager"]

View File

@ -33,10 +33,9 @@ class MicroBatchDescription:
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
new_length: int,
) -> None:
self.mb_length = inputs_dict["input_ids"].shape[-1]
self.target_length = self.mb_length + new_length
self.target_length = self.mb_length + max_output_len
self.infer_state = BatchInferState.init_from_batch(
batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
)
@ -77,7 +76,6 @@ class HeadMicroBatchDescription(MicroBatchDescription):
Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
new_length (int): the new length of the input sequence.
"""
@ -87,9 +85,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
new_length: int,
) -> None:
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
assert inputs_dict is not None
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
self.input_ids = inputs_dict["input_ids"]
@ -139,9 +136,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
new_length: int,
) -> None:
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
@property
def cur_length(self):
@ -158,7 +154,6 @@ class MicroBatchManager:
Args:
stage (int): stage id of current stage.
new_length (int): the new length of the input sequence.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should equal the number of pipeline stages.
@ -167,7 +162,6 @@ class MicroBatchManager:
def __init__(
self,
stage: int,
new_length: int,
micro_batch_size: int,
micro_batch_buffer_size: int,
max_input_len: int,
@ -175,7 +169,6 @@ class MicroBatchManager:
cache_manager_list: MemoryManager,
):
self.stage = stage
self.new_length = new_length
self.micro_batch_size = micro_batch_size
self.buffer_size = micro_batch_buffer_size
self.max_input_len = max_input_len
@ -188,11 +181,11 @@ class MicroBatchManager:
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
)
else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
)
def step(self, new_token: torch.Tensor = None):
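With `new_length` removed everywhere, a micro-batch's stopping point is now derived solely from the engine-level `max_output_len`: `target_length = prompt_length + max_output_len`. A tiny numeric sketch (variable names are illustrative, not from the diff):

```python
# Per-microbatch stopping condition after this change (illustrative numbers).
prompt_length = 32            # inputs_dict["input_ids"].shape[-1]
max_output_len = 32           # engine argument; replaces the old `new_length`
target_length = prompt_length + max_output_len
assert target_length == 64    # generation stops once cur_length reaches this
```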

View File

@ -1,10 +1,9 @@
from typing import List, Optional, Tuple
import math
import copy
from typing import List, Optional, Tuple
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
@ -16,7 +15,9 @@ try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
HAS_LIGHTLLM_KERNEL = True
@ -26,6 +27,7 @@ except:
try:
from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True
except:
HAS_FLASH_KERNEL = False
@ -50,7 +52,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1):
def llama_triton_context_attention(
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
):
if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd(
@ -87,6 +92,7 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_
infer_state.max_len_in_batch,
)
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
if num_key_value_groups == 1:
@ -265,8 +271,7 @@ class LlamaInferenceForwards:
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@staticmethod
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
@ -309,7 +314,6 @@ class LlamaInferenceForwards:
outputs += (present_key_value,)
return outputs
@staticmethod
def llama_flash_attn_kvcache_forward(
@ -358,8 +362,15 @@ class LlamaInferenceForwards:
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
llama_triton_context_attention(
query_states,
key_states,
value_states,
attn_output,
infer_state,
num_key_value_groups=self.num_key_value_groups,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@ -381,26 +392,28 @@ class LlamaInferenceForwards:
infer_state.decode_mem_index,
infer_state.cache_manager,
)
HAS_LIGHTLLM_KERNEL = False
if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states)
llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
)
else:
heads_per_group = self.num_heads // self.num_key_value_heads
self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
attn_output = flash_attn_with_kvcache(q = query_states,
k_cache = copy_cache_k,
v_cache = copy_cache_v,
softmax_scale = 1/ math.sqrt(self.head_dim),
causal = True)
attn_output = flash_attn_with_kvcache(
q=query_states,
k_cache=copy_cache_k,
v_cache=copy_cache_v,
softmax_scale=1 / math.sqrt(self.head_dim),
causal=True,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@ -408,4 +421,3 @@ class LlamaInferenceForwards:
# return past_key_value as None
return attn_output, None, None
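The grouped-query-attention handling is the other substantive change in this forward: the query projection keeps `num_heads` heads while the key/value projections (and the KV-cache views fed to `flash_attn_with_kvcache`) use only `num_key_value_heads`, with `num_key_value_groups` query heads sharing each KV head. A CPU-only shape sketch with assumed Llama-2-70B-like head counts (not values from this diff):

```python
import torch

# Assumed GQA dimensions for illustration.
bsz, q_len, head_dim = 2, 1, 128
num_heads, num_key_value_heads = 64, 8
num_key_value_groups = num_heads // num_key_value_heads  # 8 query heads per KV head

hidden = torch.randn(bsz, q_len, num_heads * head_dim)
query = hidden.view(bsz, q_len, num_heads, head_dim)
# k_proj / v_proj only produce num_key_value_heads heads after this commit:
key = torch.randn(bsz, q_len, num_key_value_heads, head_dim)
value = torch.randn(bsz, q_len, num_key_value_heads, head_dim)

print(query.shape, key.shape, num_key_value_groups)
# torch.Size([2, 1, 64, 128]) torch.Size([2, 1, 8, 128]) 8
```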

View File

@ -5,8 +5,7 @@ import transformers
from packaging import version
import colossalai
from colossalai.inference.pipeline import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@ -26,27 +25,43 @@ for k, v in inputs.items():
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
)
engine = PPInferEngine(
engine = CaiInferEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
model_policy=LlamaModelInferPolicy(),
new_length=new_length,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("new_length", [4, 8, 16])
@parameterize("micro_batch_size", [1, 4])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
pipeline_inference_test(pp_size, new_length, micro_batch_size)
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@ -55,12 +70,18 @@ def check_pipeline_inference(rank, world_size, port):
run_pipeline_inference_test()
def check_tp_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=2)
spawn(check_tp_pipeline_inference, nprocs=4)
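The spawned process counts follow directly from the engine's world-size assertion: the pipeline-only case uses `tp_size=1, pp_size=2` on 2 ranks, and the hybrid case uses `tp_size=2, pp_size=2` on 4 ranks. A quick sanity check of that relation (illustrative, mirroring the test parameters above):

```python
# World size must equal tp_size * pp_size for CaiInferEngine.
configs = [
    {"tp_size": 1, "pp_size": 2, "nprocs": 2},  # check_pipeline_inference
    {"tp_size": 2, "pp_size": 2, "nprocs": 4},  # check_tp_pipeline_inference
]
for cfg in configs:
    assert cfg["tp_size"] * cfg["pp_size"] == cfg["nprocs"]
```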
if __name__ == "__main__":