[Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo
Bin Jia 2023-11-01 12:46:21 +08:00 committed by GitHub
parent 335cb105e2
commit b6696beb04
12 changed files with 268 additions and 203 deletions
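
The commit replaces `PPInferEngine` with a single `CaiInferEngine` that drives tensor parallelism and pipeline parallelism together. Below is a minimal end-to-end sketch assembled from the docstring and test added in this diff; the checkpoint path is a placeholder, and the 4-GPU `torchrun` launch and the decode step are assumptions, not an official example from the repo.

```python
# Launch with: torchrun --nproc_per_node 4 run_infer.py
# tp_size * pp_size must equal the world size, so tp_size=2, pp_size=2 needs 4 GPUs.
import colossalai
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")        # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")      # placeholder path

engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    micro_batch_size=1,
    max_output_len=32,
)

data = tokenizer(["Introduce a landmark in China "], return_tensors="pt")
data = {k: v.to("cuda") for k, v in data.items()}
output = engine.inference(data)

if dist.get_rank() == 0:
    # output is a list of generated token-id sequences, as in the test below
    print(tokenizer.batch_decode(output))
```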

View File

@@ -1,4 +1,4 @@
-from .pipeline import PPInferEngine
+from .hybridengine import CaiInferEngine
+from .hybridengine.polices import LlamaModelInferPolicy

-__all__ = ['PPInferEngine']
+__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]

View File

@@ -0,0 +1,3 @@
+from .engine import CaiInferEngine
+
+__all__ = ["CaiInferEngine"]

View File

@@ -1,4 +1,5 @@
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from transformers.tokenization_utils_base import BatchEncoding
@@ -8,23 +9,27 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy

+from ..pipeline.microbatch_manager import MicroBatchManager
 from ..tensor_parallel.kvcache_manager import MemoryManager
-from .microbatch_manager import MicroBatchManager

+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+    "LlamaForCausalLM",
+]

-class PPInferEngine:
+class CaiInferEngine:
     """
-    PPInferEngine is a class that handles the pipeline parallel inference.
+    CaiInferEngine is a class that handles the pipeline parallel inference.

     Args:
-        pp_size (int): the number of pipeline stages.
-        pp_model (`nn.Module`): the model already in pipeline parallelism style.
+        tp_size (int): the size of tensor parallelism.
+        pp_size (int): the size of pipeline parallelism.
         model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
         model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
-        new_length (int): the new length of the input sequence.
-        early_stopping (bool): whether to stop early.
         max_batch_size (int): the maximum batch size.
         max_input_len (int): the maximum input length.
         max_output_len (int): the maximum output length.
@@ -32,7 +37,7 @@ class PPInferEngine:
    Example:

    ```python
-    from colossalai.inference import PPInferEngine
+    from colossalai.inference import InferEngine
     from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
     import colossalai
     from transformers import LlamaForCausalLM, LlamaTokenizer
@@ -42,7 +47,7 @@ class PPInferEngine:
     model = LlamaForCausalLM.from_pretrained("your_path_to_model")
     tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
     # assume the model is infered with 2 pipeline stages
-    inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
+    inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

     input = ["Introduce a landmark in China ","Introduce a landmark in China "]
     data = tokenizer(input, return_tensors='pt')
@@ -54,12 +59,11 @@ class PPInferEngine:
     def __init__(
         self,
-        pp_size: int,
+        tp_size: int = 1,
+        pp_size: int = 1,
         dtype: str = "fp16",
-        pp_model: nn.Module = None,
         model: nn.Module = None,
         model_policy: Policy = None,
-        new_length: int = 32,
         micro_batch_size: int = 1,
         micro_batch_buffer_size: int = None,
         max_batch_size: int = 4,
@@ -71,12 +75,21 @@ class PPInferEngine:
         do_sample: bool = False,
         num_beams: int = 1,
     ) -> None:
-        assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
+        assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
+        assert (
+            tp_size * pp_size == dist.get_world_size()
+        ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+        assert model and model_policy, "Model with model_policy should be provided."
         assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
-        max_output_len = max(max_output_len, max_input_len + new_length)
+        assert max_batch_size <= 64, "Max batch size exceeds the constraint"
+        assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
+        # TODO: support only tensor parallel inference
+        assert pp_size > 1, "Not support only tensor parallel inference."
         self.pp_size = pp_size
+        self.tp_size = tp_size
         if dtype == "fp16":
             self.dtype = torch.float16
             model.half()
@@ -85,24 +98,29 @@ class PPInferEngine:
             model.to(torch.bfloat16)
         else:
             self.dtype = torch.float32
-        self.pg_mesh = ProcessGroupMesh(pp_size)
-        self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
-        self.model = pp_model or self._shardformer(model, model_policy)
-        self.cache_manager_list = [
-            self._init_manager(max_batch_size, max_input_len, max_output_len)
-            for _ in range(micro_batch_buffer_size or pp_size)
-        ]
-        self.mb_manager = MicroBatchManager(
-            self.stage_manager.stage,
-            new_length,
-            micro_batch_size,
-            micro_batch_buffer_size or pp_size,
-            max_input_len,
-            max_output_len,
-            self.cache_manager_list,
-        )
-        self.verbose = verbose
-        self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
+        # Init pg mesh
+        pg_mesh = ProcessGroupMesh(pp_size, tp_size)
+
+        stage_manager = None
+        if pp_size > 1:
+            stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+            self.cache_manager_list = [
+                self._init_manager(model, max_batch_size, max_input_len, max_output_len)
+                for _ in range(micro_batch_buffer_size or pp_size)
+            ]
+            self.mb_manager = MicroBatchManager(
+                stage_manager.stage,
+                micro_batch_size,
+                micro_batch_buffer_size or pp_size,
+                max_input_len,
+                max_output_len,
+                self.cache_manager_list,
+            )
+            self.verbose = verbose
+            self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
+
+        self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
@@ -124,10 +142,10 @@ class PPInferEngine:
         else:
             return out

-    def _shardformer(self, model, model_policy):
+    def _shardformer(self, model, model_policy, stage_manager, tp_group):
         shardconfig = ShardConfig(
-            tensor_parallel_process_group=None,
-            pipeline_stage_manager=self.stage_manager,
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
             enable_tensor_parallelism=False,
             enable_fused_normalization=False,
             enable_all_optimization=False,
@@ -139,14 +157,12 @@ class PPInferEngine:
         shard_model, _ = shardformer.optimize(model, model_policy)
         return shard_model.cuda()

-    def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
+    def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
-        head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
-        head_num = self.model.config.num_attention_heads
+        head_dim = model.config.hidden_size // model.config.num_attention_heads
+        head_num = model.config.num_attention_heads
         num_hidden_layers = (
-            self.model.config.num_hidden_layers
-            if hasattr(self.model.config, "num_hidden_layers")
-            else self.model.config.num_layers
+            model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
         )
         layer_num = num_hidden_layers // self.pp_size
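
The engine now builds a 2-D `ProcessGroupMesh(pp_size, tp_size)` with `PP_AXIS, TP_AXIS = 0, 1`, and the new assert enforces `tp_size * pp_size == world_size`. The sketch below is a plain-Python illustration of how global ranks would map to (pipeline stage, tensor-parallel rank) coordinates under that factorization, assuming the mesh enumerates ranks in row-major order; it is not ColossalAI code.

```python
# Plain-Python sketch of the rank layout implied by a (pp_size, tp_size) mesh,
# assuming row-major ordering: consecutive ranks share a pipeline stage and
# differ only in their tensor-parallel coordinate.
def mesh_coords(world_size: int, pp_size: int, tp_size: int):
    assert pp_size * tp_size == world_size, "TP size * PP size must equal the world size"
    return {rank: (rank // tp_size, rank % tp_size) for rank in range(world_size)}


if __name__ == "__main__":
    # 4 GPUs, pp_size=2, tp_size=2 -> stage 0 holds ranks 0,1 and stage 1 holds ranks 2,3
    for rank, (pp_rank, tp_rank) in mesh_coords(4, pp_size=2, tp_size=2).items():
        print(f"rank {rank}: pipeline stage {pp_rank}, tensor-parallel rank {tp_rank}")
```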

View File

@@ -1,37 +1,25 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+import math
 from typing import List, Optional, Tuple

 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaForCausalLM,
-    LlamaModel,
-    LlamaRMSNorm,
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
 from transformers.utils import logging

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 from colossalai.pipeline.stage_manager import PipelineStageManager

 from ._utils import copy_kv_to_mem_cache

 try:
-    from vllm import layernorm_ops, pos_encoding_ops
-
-    rms_norm = layernorm_ops.rms_norm
-    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
-    HAS_VLLM_KERNERL = True
-except:
-    print("fall back to original rotary_embedding_neox of huggingface")
-    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
-    print(
-        "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
-    )
-    HAS_VLLM_KERNERL = False
-
-try:
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_llama2_context_attention_fwd,
+    )
+    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_context_attention_fwd,
+    )
     from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd

     HAS_LIGHTLLM_KERNEL = True
@@ -39,6 +27,14 @@ except:
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
     HAS_LIGHTLLM_KERNEL = False

+try:
+    from flash_attn import flash_attn_with_kvcache
+
+    HAS_FLASH_KERNEL = True
+except:
+    HAS_FLASH_KERNEL = False
+    print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")
+

 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -59,6 +55,75 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed


+def llama_triton_context_attention(
+    query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
+):
+    if num_key_value_groups == 1:
+        if HAS_LIGHTLLM_KERNEL is False:
+            llama_context_attn_fwd(
+                query_states,
+                key_states,
+                value_states,
+                attn_output,
+                infer_state.start_loc,
+                infer_state.seq_len,
+                # infer_state.cache_manager.past_key_values_length,
+                infer_state.max_len_in_batch,
+            )
+        else:
+            lightllm_context_attention_fwd(
+                query_states,
+                key_states,
+                value_states,
+                attn_output,
+                infer_state.start_loc,
+                infer_state.seq_len,
+                # infer_state.cache_manager.past_key_values_length,
+                infer_state.max_len_in_batch,
+            )
+    else:
+        assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
+        lightllm_llama2_context_attention_fwd(
+            query_states,
+            key_states,
+            value_states,
+            attn_output,
+            infer_state.start_loc,
+            infer_state.seq_len,
+            # infer_state.cache_manager.past_key_values_length,
+            infer_state.max_len_in_batch,
+        )
+
+
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
+    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
+    if num_key_value_groups == 1:
+        token_attention_fwd(
+            query_states,
+            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+            attn_output,
+            infer_state.block_loc,
+            infer_state.start_loc,
+            infer_state.seq_len,
+            # infer_state.cache_manager.past_key_values_length,
+            infer_state.max_len_in_batch,
+        )
+    else:
+        Llama2TokenAttentionForwards.token_attn(
+            query_states,
+            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+            attn_output,
+            infer_state.block_loc,
+            infer_state.start_loc,
+            infer_state.seq_len,
+            # infer_state.cache_manager.past_key_values_length,
+            infer_state.max_len_in_batch,
+            infer_state.other_kv_index,
+        )
+
+
 class LlamaInferenceForwards:
     """
     This class holds forwards for llama inference.
@@ -144,13 +209,9 @@ class LlamaInferenceForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
     ):
-        # batch_size = input_ids.shape[0]
-        # print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}")
-        # infer_state = self.infer_state
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        # input_ids.shape[0]
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
         # retrieve input_ids and inputs_embeds
         if stage_manager is None or stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
@@ -172,12 +233,10 @@ class LlamaInferenceForwards:
         batch_size, seq_length = input_shape
         device = hidden_states.device

-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-
-        if infer_state.is_context_stage is False:
-            past_key_values_length = infer_state.cache_manager.past_key_values_length
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+        if infer_state.is_context_stage:
+            past_key_values_length = 0
+        else:
+            past_key_values_length = infer_state.max_len_in_batch - 1

         # NOTE: differentiate with prefill stage
         # block_loc require different value-assigning method for two different stage
@@ -197,26 +256,19 @@ class LlamaInferenceForwards:
                 infer_state.decode_mem_index = alloc_mem[0]
                 infer_state.decode_mem_start = alloc_mem[1]
                 infer_state.decode_mem_end = alloc_mem[2]
-                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+                infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
             else:
-                print(f" *** Encountered allocation non-contiguous")
-                print(
-                    f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
-                )
                 infer_state.decode_is_contiguous = False
                 alloc_mem = infer_state.cache_manager.alloc(batch_size)
                 infer_state.decode_mem_index = alloc_mem
-                # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-                # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+                infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index

         if position_ids is None:
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0)
-            new_shape = [1] * position_ids.dim()
-            new_shape[0] = batch_size
-            position_ids = position_ids.repeat(*new_shape).view(-1, seq_length)
+            position_ids = position_ids.repeat(batch_size, 1)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
             position_ids = position_ids.view(-1, seq_length).long()
@@ -227,15 +279,17 @@ class LlamaInferenceForwards:
             infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
                 position_ids.view(-1).shape[0], -1
             )
         else:
             seq_len = infer_state.seq_len
             infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
             infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()

         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+                (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
             )

         attention_mask = self._prepare_decoder_attention_mask(
@@ -243,10 +297,6 @@ class LlamaInferenceForwards:
         )

         # decoder layers
-        () if output_hidden_states else None
-        () if output_attentions else None
-        next_decoder_cache = () if use_cache else None

         infer_state.decode_layer_id = 0
         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -268,19 +318,15 @@ class LlamaInferenceForwards:
             infer_state.decode_layer_id += 1
             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
             hidden_states = self.norm(hidden_states)
-        next_cache = next_decoder_cache if use_cache else None

         # update indices
         # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
-        infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+        infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         infer_state.seq_len += 1
+        infer_state.max_len_in_batch += 1

+        # TODO: fix this to necessary return
         # if not return_dict:
         #     return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -290,8 +336,7 @@ class LlamaInferenceForwards:
         #     hidden_states=all_hidden_states,
         #     attentions=all_self_attns,
         # )
-        # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}")
-        return {"hidden_states": hidden_states, "past_key_values": next_cache}
+        return {"hidden_states": hidden_states}

     @staticmethod
     def llama_decoder_layer_forward(
@@ -307,7 +352,6 @@ class LlamaInferenceForwards:
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)
-
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -357,28 +401,24 @@ class LlamaInferenceForwards:
         # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)

         # NOTE might want to revise
         # need some way to record the length of past key values cache
         # since we won't return past_key_value_cache right now
-        if infer_state.decode_layer_id == 0:  # once per model.forward
-            infer_state.cache_manager.past_key_values_length += q_len  # seq_len

         cos, sin = infer_state.position_cos, infer_state.position_sin

         llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)

         if infer_state.is_context_stage:
-            # print(f"rank:{torch.distributed.get_rank()}, {infer_state}")
             # first token generation
             # copy key and value calculated in current step to memory manager
             copy_kv_to_mem_cache(
                 infer_state.decode_layer_id,
@@ -387,19 +427,16 @@ class LlamaInferenceForwards:
                 infer_state.context_mem_index,
                 infer_state.cache_manager,
             )
             attn_output = torch.empty_like(query_states)

-            llama_context_attn_fwd(
+            llama_triton_context_attention(
                 query_states,
                 key_states,
                 value_states,
                 attn_output,
-                infer_state.start_loc,
-                infer_state.seq_len,
-                infer_state.cache_manager.past_key_values_length,
+                infer_state,
+                num_key_value_groups=self.num_key_value_groups,
             )
         else:
             if infer_state.decode_is_contiguous:
                 # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -422,45 +459,31 @@ class LlamaInferenceForwards:
                     infer_state.cache_manager,
                 )

-            # second token and follows
-            # kv = torch.stack((key_states, value_states), dim=2)
-            # (batch_size, seqlen, nheads, headdim)
-            attn_output = torch.empty_like(query_states)
-
-            token_attention_fwd(
-                query_states,
-                infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
-                infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
-                attn_output,
-                infer_state.block_loc,
-                infer_state.start_loc,
-                infer_state.seq_len,
-                infer_state.cache_manager.past_key_values_length,
-            )
+            if HAS_LIGHTLLM_KERNEL:
+                attn_output = torch.empty_like(query_states)
+                llama_triton_token_attention(
+                    query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
+                )
+            else:
+                self.num_heads // self.num_key_value_heads
+                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
+                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
+
+                query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
+                copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+                copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+
+                attn_output = flash_attn_with_kvcache(
+                    q=query_states,
+                    k_cache=copy_cache_k,
+                    v_cache=copy_cache_v,
+                    softmax_scale=1 / math.sqrt(self.head_dim),
+                    causal=True,
+                )

         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-        # print(f"rank:{torch.distributed.get_rank()}, {attn_output}")

         attn_output = self.o_proj(attn_output)

         # return past_key_value as None
         return attn_output, None, None
-

-def get_llama_vllm_rmsnorm_forward():
-    if HAS_VLLM_KERNERL:
-
-        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            x = hidden_states
-            out = torch.empty_like(x)
-            rms_norm(
-                out,
-                x,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return out
-
-        return _vllm_rmsnorm_forward
-    else:
-        return None
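
In the decode branch above, each step writes the new key/value into the cache and attends a single query token against everything cached, dispatching to the lightllm token-attention kernel or to flash_attn_with_kvcache, with `num_key_value_heads` possibly smaller than `num_heads` (grouped-query attention). The snippet below is a hedged pure-PyTorch reference of that single-token decode attention, useful for checking shapes; it is illustrative only and is not the Triton or flash-attn kernel the engine actually calls.

```python
# Pure-PyTorch reference for single-token (decode) attention against a KV cache,
# including grouped-query attention where several query heads share one KV head.
import math
import torch


def decode_attention_reference(q, k_cache, v_cache):
    # q:       (batch, num_heads, head_dim)              -- the single new token
    # k_cache: (batch, seq_len, num_kv_heads, head_dim)  -- cached keys incl. new token
    # v_cache: (batch, seq_len, num_kv_heads, head_dim)
    bsz, num_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[2]
    groups = num_heads // num_kv_heads

    # Expand KV heads so every query head has a matching KV head.
    k = k_cache.repeat_interleave(groups, dim=2).permute(0, 2, 1, 3)  # (b, h, s, d)
    v = v_cache.repeat_interleave(groups, dim=2).permute(0, 2, 1, 3)

    scores = torch.einsum("bhd,bhsd->bhs", q, k) / math.sqrt(head_dim)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhs,bhsd->bhd", probs, v)  # (batch, num_heads, head_dim)


if __name__ == "__main__":
    b, s, h, kvh, d = 2, 5, 8, 4, 16
    q = torch.randn(b, h, d)
    k_cache, v_cache = torch.randn(b, s, kvh, d), torch.randn(b, s, kvh, d)
    print(decode_attention_reference(q, k_cache, v_cache).shape)  # torch.Size([2, 8, 16])
```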

View File

@@ -17,7 +17,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

 from ..modeling._utils import init_to_get_rotary
-from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
+from ..modeling.llama import LlamaInferenceForwards

 try:
     from colossalai.kernel.triton import rmsnorm_forward
@@ -120,9 +120,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()
-        else:
-            # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
-            infer_forward = get_llama_vllm_rmsnorm_forward()

         if infer_forward is not None:
             method_replacement = {"forward": partial(infer_forward)}
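
With the vLLM fallback removed, the policy only swaps in the Triton `rmsnorm_forward` when it is available; otherwise the stock `LlamaRMSNorm` forward is kept. For reference, this is the operation the kernel accelerates, shown as a hedged pure-PyTorch sketch (illustrative only, not the code the policy installs).

```python
# Reference RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight over the hidden dimension,
# computed in fp32 and cast back, matching LlamaRMSNorm's semantics.
import torch


def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


if __name__ == "__main__":
    hidden = torch.randn(2, 4, 64)
    weight = torch.ones(64)
    print(rmsnorm_reference(hidden, weight).shape)  # torch.Size([2, 4, 64])
```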

View File

@@ -1,3 +1,3 @@
-from .engine import PPInferEngine
+from .microbatch_manager import MicroBatchManager

-__all__ = ["PPInferEngine"]
+__all__ = ["MicroBatchManager"]

View File

@@ -33,10 +33,9 @@ class MicroBatchDescription:
         max_input_len: int,
         max_output_len: int,
         cache_manager: MemoryManager,
-        new_length: int,
     ) -> None:
         self.mb_length = inputs_dict["input_ids"].shape[-1]
-        self.target_length = self.mb_length + new_length
+        self.target_length = self.mb_length + max_output_len
         self.infer_state = BatchInferState.init_from_batch(
             batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
         )
@@ -77,7 +76,6 @@ class HeadMicroBatchDescription(MicroBatchDescription):
     Args:
         inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
         output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
-        new_length (int): the new length of the input sequence.

     """
@@ -87,9 +85,8 @@
         max_input_len: int,
         max_output_len: int,
         cache_manager: MemoryManager,
-        new_length: int,
     ) -> None:
-        super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
+        super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
         assert inputs_dict is not None
         assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
         self.input_ids = inputs_dict["input_ids"]
@@ -139,9 +136,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         max_input_len: int,
         max_output_len: int,
         cache_manager: MemoryManager,
-        new_length: int,
     ) -> None:
-        super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
+        super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)

     @property
     def cur_length(self):
@@ -158,7 +154,6 @@ class MicroBatchManager:
     Args:
         stage (int): stage id of current stage.
-        new_length (int): the new length of the input sequence.
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
@@ -167,7 +162,6 @@
     def __init__(
         self,
         stage: int,
-        new_length: int,
         micro_batch_size: int,
         micro_batch_buffer_size: int,
         max_input_len: int,
@@ -175,7 +169,6 @@
         cache_manager_list: MemoryManager,
     ):
         self.stage = stage
-        self.new_length = new_length
         self.micro_batch_size = micro_batch_size
         self.buffer_size = micro_batch_buffer_size
         self.max_input_len = max_input_len
@@ -188,11 +181,11 @@
     def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
         if self.stage == 0:
             self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
-                inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
+                inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
            )
         else:
             self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
-                inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
+                inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
            )

    def step(self, new_token: torch.Tensor = None):
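
With `new_length` gone, a micro batch's stopping point is derived directly from `max_output_len`: `target_length = self.mb_length + max_output_len`, where `mb_length` is the prompt length. The sketch below only illustrates that bookkeeping; it is a simplified stand-in, not the MicroBatchManager's actual code.

```python
# Illustration of the new stopping criterion: generation for a micro batch continues
# until its current length reaches target_length = prompt_length + max_output_len.
def generation_steps(prompt_length: int, max_output_len: int) -> int:
    target_length = prompt_length + max_output_len
    cur_length = prompt_length
    steps = 0
    while cur_length < target_length:
        cur_length += 1  # one decoded token per pipeline pass
        steps += 1
    return steps


assert generation_steps(prompt_length=16, max_output_len=4) == 4
```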

View File

@@ -1,10 +1,9 @@
-from typing import List, Optional, Tuple
 import math
-import copy
+from typing import List, Optional, Tuple

 import torch
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
@@ -16,7 +15,9 @@ try:
     from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
         context_attention_fwd as lightllm_llama2_context_attention_fwd,
     )
-    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd
+    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_context_attention_fwd,
+    )
     from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd

     HAS_LIGHTLLM_KERNEL = True
@@ -26,6 +27,7 @@ except:
 try:
     from flash_attn import flash_attn_with_kvcache
+
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
@@ -50,7 +52,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

-def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1):
+
+def llama_triton_context_attention(
+    query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
+):
     if num_key_value_groups == 1:
         if HAS_LIGHTLLM_KERNEL is False:
             llama_context_attn_fwd(
@@ -87,6 +92,7 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_
                 infer_state.max_len_in_batch,
             )

+
 def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
     assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
     if num_key_value_groups == 1:
@@ -265,8 +271,7 @@
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-

     @staticmethod
     def llama_decoder_layer_forward(
         self: LlamaDecoderLayer,
@@ -309,7 +314,6 @@
             outputs += (present_key_value,)

         return outputs
-

     @staticmethod
     def llama_flash_attn_kvcache_forward(
@@ -358,8 +362,15 @@
                 infer_state.cache_manager,
             )
             attn_output = torch.empty_like(query_states)
-            llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
+            llama_triton_context_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_output,
+                infer_state,
+                num_key_value_groups=self.num_key_value_groups,
+            )
         else:
             if infer_state.decode_is_contiguous:
                 # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -381,26 +392,28 @@
                     infer_state.decode_mem_index,
                     infer_state.cache_manager,
                 )
-            HAS_LIGHTLLM_KERNEL = False
             if HAS_LIGHTLLM_KERNEL:
                 attn_output = torch.empty_like(query_states)
-                llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
+                llama_triton_token_attention(
+                    query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
+                )
             else:
-                heads_per_group = self.num_heads // self.num_key_value_heads
+                self.num_heads // self.num_key_value_heads
                 cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
                 cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]

                 query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
                 copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
                 copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)

-                attn_output = flash_attn_with_kvcache(q = query_states,
-                    k_cache = copy_cache_k,
-                    v_cache = copy_cache_v,
-                    softmax_scale = 1/ math.sqrt(self.head_dim),
-                    causal = True)
+                attn_output = flash_attn_with_kvcache(
+                    q=query_states,
+                    k_cache=copy_cache_k,
+                    v_cache=copy_cache_v,
+                    softmax_scale=1 / math.sqrt(self.head_dim),
+                    causal=True,
+                )

         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -408,4 +421,3 @@
         # return past_key_value as None
         return attn_output, None, None
-

View File

@@ -5,8 +5,7 @@ import transformers
 from packaging import version

 import colossalai
-from colossalai.inference.pipeline import PPInferEngine
-from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
+from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@@ -26,27 +25,43 @@ for k, v in inputs.items():
     inputs[k] = v.to("cuda").repeat(*new_shape)


-def pipeline_inference_test(pp_size, new_length, micro_batch_size):
-    model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
+def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    model = transformers.LlamaForCausalLM(
+        transformers.LlamaConfig(
+            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+        )
+    )

-    engine = PPInferEngine(
+    engine = CaiInferEngine(
+        tp_size=tp_size,
         pp_size=pp_size,
         model=model,
         model_policy=LlamaModelInferPolicy(),
-        new_length=new_length,
+        max_output_len=max_output_len,
         micro_batch_size=micro_batch_size,
     )
     output = engine.inference(inputs)
     if dist.get_rank() == 0:
-        assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
+        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"


+@parameterize("tp_size", [1])
 @parameterize("pp_size", [2])
-@parameterize("new_length", [4, 8, 16])
-@parameterize("micro_batch_size", [1, 4])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
 @clear_cache_before_run()
-def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
-    pipeline_inference_test(pp_size, new_length, micro_batch_size)
+def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
     torch.cuda.empty_cache()
@@ -55,12 +70,18 @@ def check_pipeline_inference(rank, world_size, port):
     run_pipeline_inference_test()


+def check_tp_pipeline_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_tp_pipeline_inference_test()
+
+
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_pipeline_inference():
     spawn(check_pipeline_inference, nprocs=2)
+    spawn(check_tp_pipeline_inference, nprocs=4)


 if __name__ == "__main__":
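
The added `spawn(check_tp_pipeline_inference, nprocs=4)` call means this test now needs at least four GPUs, because the engine asserts `tp_size * pp_size == world_size` and the TP case uses tp_size=2, pp_size=2. The snippet below is a small hedged sketch of guarding such a run on machines with fewer devices; the helper name is illustrative and not part of the test file.

```python
# Sketch of a GPU-count guard for the 4-rank TP + PP test; illustrative only.
import torch


def has_enough_gpus(tp_size: int = 2, pp_size: int = 2) -> bool:
    # world size must equal tp_size * pp_size, so we need that many devices
    return torch.cuda.device_count() >= tp_size * pp_size


if __name__ == "__main__":
    if has_enough_gpus():
        print("OK to spawn 4 ranks for the TP + PP inference test")
    else:
        print("Need at least 4 GPUs: world size must equal tp_size * pp_size")
```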