mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline Inference] Merge pp with tp (#4993)
* refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todopull/4998/head
parent
335cb105e2
commit
b6696beb04
|
@ -1,4 +1,4 @@
|
|||
from .pipeline import PPInferEngine
|
||||
from .hybridengine import CaiInferEngine
|
||||
from .hybridengine.polices import LlamaModelInferPolicy
|
||||
|
||||
|
||||
__all__ = ['PPInferEngine']
|
||||
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .engine import CaiInferEngine
|
||||
|
||||
__all__ = ["CaiInferEngine"]
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
|
||||
|
@ -8,23 +9,27 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from ..pipeline.microbatch_manager import MicroBatchManager
|
||||
from ..tensor_parallel.kvcache_manager import MemoryManager
|
||||
from .microbatch_manager import MicroBatchManager
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = [
|
||||
"LlamaForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
class PPInferEngine:
|
||||
class CaiInferEngine:
|
||||
"""
|
||||
PPInferEngine is a class that handles the pipeline parallel inference.
|
||||
CaiInferEngine is a class that handles the pipeline parallel inference.
|
||||
|
||||
Args:
|
||||
pp_size (int): the number of pipeline stages.
|
||||
pp_model (`nn.Module`): the model already in pipeline parallelism style.
|
||||
tp_size (int): the size of tensor parallelism.
|
||||
pp_size (int): the size of pipeline parallelism.
|
||||
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
|
||||
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
new_length (int): the new length of the input sequence.
|
||||
early_stopping (bool): whether to stop early.
|
||||
max_batch_size (int): the maximum batch size.
|
||||
max_input_len (int): the maximum input length.
|
||||
max_output_len (int): the maximum output length.
|
||||
|
@ -32,7 +37,7 @@ class PPInferEngine:
|
|||
Example:
|
||||
|
||||
```python
|
||||
from colossalai.inference import PPInferEngine
|
||||
from colossalai.inference import InferEngine
|
||||
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
|
||||
import colossalai
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
@ -42,7 +47,7 @@ class PPInferEngine:
|
|||
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
|
||||
# assume the model is infered with 2 pipeline stages
|
||||
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
|
||||
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
|
||||
|
||||
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
|
||||
data = tokenizer(input, return_tensors='pt')
|
||||
|
@ -54,12 +59,11 @@ class PPInferEngine:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
pp_size: int,
|
||||
tp_size: int = 1,
|
||||
pp_size: int = 1,
|
||||
dtype: str = "fp16",
|
||||
pp_model: nn.Module = None,
|
||||
model: nn.Module = None,
|
||||
model_policy: Policy = None,
|
||||
new_length: int = 32,
|
||||
micro_batch_size: int = 1,
|
||||
micro_batch_buffer_size: int = None,
|
||||
max_batch_size: int = 4,
|
||||
|
@ -71,12 +75,21 @@ class PPInferEngine:
|
|||
do_sample: bool = False,
|
||||
num_beams: int = 1,
|
||||
) -> None:
|
||||
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
|
||||
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
|
||||
assert (
|
||||
tp_size * pp_size == dist.get_world_size()
|
||||
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
|
||||
assert model and model_policy, "Model with model_policy should be provided."
|
||||
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
|
||||
max_output_len = max(max_output_len, max_input_len + new_length)
|
||||
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
|
||||
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
|
||||
|
||||
# TODO: support only tensor parallel inference
|
||||
assert pp_size > 1, "Not support only tensor parallel inference."
|
||||
self.pp_size = pp_size
|
||||
self.tp_size = tp_size
|
||||
|
||||
if dtype == "fp16":
|
||||
self.dtype = torch.float16
|
||||
model.half()
|
||||
|
@ -85,24 +98,29 @@ class PPInferEngine:
|
|||
model.to(torch.bfloat16)
|
||||
else:
|
||||
self.dtype = torch.float32
|
||||
self.pg_mesh = ProcessGroupMesh(pp_size)
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
|
||||
self.model = pp_model or self._shardformer(model, model_policy)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
]
|
||||
self.mb_manager = MicroBatchManager(
|
||||
self.stage_manager.stage,
|
||||
new_length,
|
||||
micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size,
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
self.cache_manager_list,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
|
||||
|
||||
# Init pg mesh
|
||||
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
|
||||
|
||||
stage_manager = None
|
||||
if pp_size > 1:
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
]
|
||||
self.mb_manager = MicroBatchManager(
|
||||
stage_manager.stage,
|
||||
micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size,
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
self.cache_manager_list,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
|
||||
|
||||
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
|
||||
|
||||
def inference(self, input_list):
|
||||
"""
|
||||
|
@ -124,10 +142,10 @@ class PPInferEngine:
|
|||
else:
|
||||
return out
|
||||
|
||||
def _shardformer(self, model, model_policy):
|
||||
def _shardformer(self, model, model_policy, stage_manager, tp_group):
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=None,
|
||||
pipeline_stage_manager=self.stage_manager,
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=False,
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
|
@ -139,14 +157,12 @@ class PPInferEngine:
|
|||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
||||
|
||||
def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
|
||||
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
|
||||
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
|
||||
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
|
||||
head_num = self.model.config.num_attention_heads
|
||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
head_num = model.config.num_attention_heads
|
||||
num_hidden_layers = (
|
||||
self.model.config.num_hidden_layers
|
||||
if hasattr(self.model.config, "num_hidden_layers")
|
||||
else self.model.config.num_layers
|
||||
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
|
||||
)
|
||||
layer_num = num_hidden_layers // self.pp_size
|
||||
|
|
@ -1,37 +1,25 @@
|
|||
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
|
||||
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
try:
|
||||
from vllm import layernorm_ops, pos_encoding_ops
|
||||
|
||||
rms_norm = layernorm_ops.rms_norm
|
||||
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
|
||||
HAS_VLLM_KERNERL = True
|
||||
except:
|
||||
print("fall back to original rotary_embedding_neox of huggingface")
|
||||
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
|
||||
print(
|
||||
"if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
|
||||
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_llama2_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_context_attention_fwd,
|
||||
)
|
||||
HAS_VLLM_KERNERL = False
|
||||
|
||||
try:
|
||||
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
|
||||
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
|
@ -39,6 +27,14 @@ except:
|
|||
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
|
||||
HAS_FLASH_KERNEL = True
|
||||
except:
|
||||
HAS_FLASH_KERNEL = False
|
||||
print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
|
@ -59,6 +55,75 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|||
return q_embed, k_embed
|
||||
|
||||
|
||||
def llama_triton_context_attention(
|
||||
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
|
||||
):
|
||||
if num_key_value_groups == 1:
|
||||
if HAS_LIGHTLLM_KERNEL is False:
|
||||
llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
lightllm_context_attention_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
|
||||
lightllm_llama2_context_attention_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
|
||||
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
|
||||
if num_key_value_groups == 1:
|
||||
token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
Llama2TokenAttentionForwards.token_attn(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
infer_state.other_kv_index,
|
||||
)
|
||||
|
||||
|
||||
class LlamaInferenceForwards:
|
||||
"""
|
||||
This class holds forwards for llama inference.
|
||||
|
@ -144,13 +209,9 @@ class LlamaInferenceForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
# batch_size = input_ids.shape[0] # input_ids.shape[0]
|
||||
# print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}")
|
||||
|
||||
# infer_state = self.infer_state
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
|
@ -172,12 +233,10 @@ class LlamaInferenceForwards:
|
|||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if infer_state.is_context_stage is False:
|
||||
past_key_values_length = infer_state.cache_manager.past_key_values_length
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
if infer_state.is_context_stage:
|
||||
past_key_values_length = 0
|
||||
else:
|
||||
past_key_values_length = infer_state.max_len_in_batch - 1
|
||||
|
||||
# NOTE: differentiate with prefill stage
|
||||
# block_loc require different value-assigning method for two different stage
|
||||
|
@ -197,26 +256,19 @@ class LlamaInferenceForwards:
|
|||
infer_state.decode_mem_index = alloc_mem[0]
|
||||
infer_state.decode_mem_start = alloc_mem[1]
|
||||
infer_state.decode_mem_end = alloc_mem[2]
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
|
||||
else:
|
||||
print(f" *** Encountered allocation non-contiguous")
|
||||
print(
|
||||
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
|
||||
)
|
||||
infer_state.decode_is_contiguous = False
|
||||
alloc_mem = infer_state.cache_manager.alloc(batch_size)
|
||||
infer_state.decode_mem_index = alloc_mem
|
||||
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
|
||||
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
new_shape = [1] * position_ids.dim()
|
||||
new_shape[0] = batch_size
|
||||
position_ids = position_ids.repeat(*new_shape).view(-1, seq_length)
|
||||
position_ids = position_ids.repeat(batch_size, 1)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
|
@ -227,15 +279,17 @@ class LlamaInferenceForwards:
|
|||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
|
||||
else:
|
||||
seq_len = infer_state.seq_len
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
|
@ -243,10 +297,6 @@ class LlamaInferenceForwards:
|
|||
)
|
||||
|
||||
# decoder layers
|
||||
() if output_hidden_states else None
|
||||
() if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
infer_state.decode_layer_id = 0
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
@ -268,19 +318,15 @@ class LlamaInferenceForwards:
|
|||
infer_state.decode_layer_id += 1
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# update indices
|
||||
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.seq_len += 1
|
||||
infer_state.max_len_in_batch += 1
|
||||
|
||||
# TODO: fix this to necessary return
|
||||
# if not return_dict:
|
||||
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
|
@ -290,8 +336,7 @@ class LlamaInferenceForwards:
|
|||
# hidden_states=all_hidden_states,
|
||||
# attentions=all_self_attns,
|
||||
# )
|
||||
# print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}")
|
||||
return {"hidden_states": hidden_states, "past_key_values": next_cache}
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def llama_decoder_layer_forward(
|
||||
|
@ -307,7 +352,6 @@ class LlamaInferenceForwards:
|
|||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
|
@ -357,28 +401,24 @@ class LlamaInferenceForwards:
|
|||
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# NOTE might want to revise
|
||||
# need some way to record the length of past key values cache
|
||||
# since we won't return past_key_value_cache right now
|
||||
if infer_state.decode_layer_id == 0: # once per model.forward
|
||||
infer_state.cache_manager.past_key_values_length += q_len # seq_len
|
||||
|
||||
cos, sin = infer_state.position_cos, infer_state.position_sin
|
||||
|
||||
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||
llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
|
||||
|
||||
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
# print(f"rank:{torch.distributed.get_rank()}, {infer_state}")
|
||||
# first token generation
|
||||
|
||||
# copy key and value calculated in current step to memory manager
|
||||
copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
|
@ -387,19 +427,16 @@ class LlamaInferenceForwards:
|
|||
infer_state.context_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
llama_context_attn_fwd(
|
||||
llama_triton_context_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
infer_state,
|
||||
num_key_value_groups=self.num_key_value_groups,
|
||||
)
|
||||
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
|
@ -422,45 +459,31 @@ class LlamaInferenceForwards:
|
|||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
# second token and follows
|
||||
# kv = torch.stack((key_states, value_states), dim=2)
|
||||
# (batch_size, seqlen, nheads, headdim)
|
||||
attn_output = torch.empty_like(query_states)
|
||||
if HAS_LIGHTLLM_KERNEL:
|
||||
attn_output = torch.empty_like(query_states)
|
||||
llama_triton_token_attention(
|
||||
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
|
||||
)
|
||||
else:
|
||||
self.num_heads // self.num_key_value_heads
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
|
||||
|
||||
token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
|
||||
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
||||
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
attn_output = flash_attn_with_kvcache(
|
||||
q=query_states,
|
||||
k_cache=copy_cache_k,
|
||||
v_cache=copy_cache_v,
|
||||
softmax_scale=1 / math.sqrt(self.head_dim),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
# print(f"rank:{torch.distributed.get_rank()}, {attn_output}")
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
# return past_key_value as None
|
||||
return attn_output, None, None
|
||||
|
||||
|
||||
def get_llama_vllm_rmsnorm_forward():
|
||||
if HAS_VLLM_KERNERL:
|
||||
|
||||
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
||||
x = hidden_states
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(
|
||||
out,
|
||||
x,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
return _vllm_rmsnorm_forward
|
||||
else:
|
||||
return None
|
|
@ -17,7 +17,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
|
|||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
from ..modeling._utils import init_to_get_rotary
|
||||
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
|
||||
from ..modeling.llama import LlamaInferenceForwards
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import rmsnorm_forward
|
||||
|
@ -120,9 +120,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
|||
infer_forward = None
|
||||
if HAS_TRITON_RMSNORM:
|
||||
infer_forward = get_triton_rmsnorm_forward()
|
||||
else:
|
||||
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
|
||||
infer_forward = get_llama_vllm_rmsnorm_forward()
|
||||
|
||||
if infer_forward is not None:
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
|
@ -1,3 +1,3 @@
|
|||
from .engine import PPInferEngine
|
||||
from .microbatch_manager import MicroBatchManager
|
||||
|
||||
__all__ = ["PPInferEngine"]
|
||||
__all__ = ["MicroBatchManager"]
|
||||
|
|
|
@ -33,10 +33,9 @@ class MicroBatchDescription:
|
|||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
new_length: int,
|
||||
) -> None:
|
||||
self.mb_length = inputs_dict["input_ids"].shape[-1]
|
||||
self.target_length = self.mb_length + new_length
|
||||
self.target_length = self.mb_length + max_output_len
|
||||
self.infer_state = BatchInferState.init_from_batch(
|
||||
batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
|
||||
)
|
||||
|
@ -77,7 +76,6 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
new_length (int): the new length of the input sequence.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -87,9 +85,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
new_length: int,
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
|
||||
assert inputs_dict is not None
|
||||
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
|
||||
self.input_ids = inputs_dict["input_ids"]
|
||||
|
@ -139,9 +136,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
|
|||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
new_length: int,
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
|
@ -158,7 +154,6 @@ class MicroBatchManager:
|
|||
|
||||
Args:
|
||||
stage (int): stage id of current stage.
|
||||
new_length (int): the new length of the input sequence.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
|
||||
|
@ -167,7 +162,6 @@ class MicroBatchManager:
|
|||
def __init__(
|
||||
self,
|
||||
stage: int,
|
||||
new_length: int,
|
||||
micro_batch_size: int,
|
||||
micro_batch_buffer_size: int,
|
||||
max_input_len: int,
|
||||
|
@ -175,7 +169,6 @@ class MicroBatchManager:
|
|||
cache_manager_list: MemoryManager,
|
||||
):
|
||||
self.stage = stage
|
||||
self.new_length = new_length
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.buffer_size = micro_batch_buffer_size
|
||||
self.max_input_len = max_input_len
|
||||
|
@ -188,11 +181,11 @@ class MicroBatchManager:
|
|||
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
|
||||
def step(self, new_token: torch.Tensor = None):
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from typing import List, Optional, Tuple
|
||||
import math
|
||||
import copy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
|
||||
|
@ -16,7 +15,9 @@ try:
|
|||
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_llama2_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd
|
||||
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
|
||||
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
|
@ -26,6 +27,7 @@ except:
|
|||
|
||||
try:
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
|
||||
HAS_FLASH_KERNEL = True
|
||||
except:
|
||||
HAS_FLASH_KERNEL = False
|
||||
|
@ -50,7 +52,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1):
|
||||
|
||||
def llama_triton_context_attention(
|
||||
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
|
||||
):
|
||||
if num_key_value_groups == 1:
|
||||
if HAS_LIGHTLLM_KERNEL is False:
|
||||
llama_context_attn_fwd(
|
||||
|
@ -87,6 +92,7 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_
|
|||
infer_state.max_len_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
|
||||
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
|
||||
if num_key_value_groups == 1:
|
||||
|
@ -265,8 +271,7 @@ class LlamaInferenceForwards:
|
|||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def llama_decoder_layer_forward(
|
||||
self: LlamaDecoderLayer,
|
||||
|
@ -309,7 +314,6 @@ class LlamaInferenceForwards:
|
|||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@staticmethod
|
||||
def llama_flash_attn_kvcache_forward(
|
||||
|
@ -358,8 +362,15 @@ class LlamaInferenceForwards:
|
|||
infer_state.cache_manager,
|
||||
)
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
|
||||
|
||||
llama_triton_context_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state,
|
||||
num_key_value_groups=self.num_key_value_groups,
|
||||
)
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
|
@ -381,26 +392,28 @@ class LlamaInferenceForwards:
|
|||
infer_state.decode_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
|
||||
if HAS_LIGHTLLM_KERNEL:
|
||||
attn_output = torch.empty_like(query_states)
|
||||
llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
|
||||
llama_triton_token_attention(
|
||||
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
|
||||
)
|
||||
else:
|
||||
heads_per_group = self.num_heads // self.num_key_value_heads
|
||||
self.num_heads // self.num_key_value_heads
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
|
||||
|
||||
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
|
||||
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
||||
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
attn_output = flash_attn_with_kvcache(q = query_states,
|
||||
k_cache = copy_cache_k,
|
||||
v_cache = copy_cache_v,
|
||||
softmax_scale = 1/ math.sqrt(self.head_dim),
|
||||
causal = True)
|
||||
|
||||
attn_output = flash_attn_with_kvcache(
|
||||
q=query_states,
|
||||
k_cache=copy_cache_k,
|
||||
v_cache=copy_cache_v,
|
||||
softmax_scale=1 / math.sqrt(self.head_dim),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
|
@ -408,4 +421,3 @@ class LlamaInferenceForwards:
|
|||
|
||||
# return past_key_value as None
|
||||
return attn_output, None, None
|
||||
|
||||
|
|
|
@ -5,8 +5,7 @@ import transformers
|
|||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.pipeline import PPInferEngine
|
||||
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
|
||||
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||
|
@ -26,27 +25,43 @@ for k, v in inputs.items():
|
|||
inputs[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
|
||||
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
|
||||
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
||||
model = transformers.LlamaForCausalLM(
|
||||
transformers.LlamaConfig(
|
||||
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
|
||||
)
|
||||
)
|
||||
|
||||
engine = PPInferEngine(
|
||||
engine = CaiInferEngine(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
model_policy=LlamaModelInferPolicy(),
|
||||
new_length=new_length,
|
||||
max_output_len=max_output_len,
|
||||
micro_batch_size=micro_batch_size,
|
||||
)
|
||||
output = engine.inference(inputs)
|
||||
if dist.get_rank() == 0:
|
||||
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
|
||||
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
|
||||
|
||||
|
||||
@parameterize("tp_size", [1])
|
||||
@parameterize("pp_size", [2])
|
||||
@parameterize("new_length", [4, 8, 16])
|
||||
@parameterize("micro_batch_size", [1, 4])
|
||||
@parameterize("max_output_len", [4])
|
||||
@parameterize("micro_batch_size", [1])
|
||||
@clear_cache_before_run()
|
||||
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
pipeline_inference_test(pp_size, new_length, micro_batch_size)
|
||||
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
||||
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize("tp_size", [2])
|
||||
@parameterize("pp_size", [2])
|
||||
@parameterize("max_output_len", [4])
|
||||
@parameterize("micro_batch_size", [1])
|
||||
@clear_cache_before_run()
|
||||
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
||||
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
@ -55,12 +70,18 @@ def check_pipeline_inference(rank, world_size, port):
|
|||
run_pipeline_inference_test()
|
||||
|
||||
|
||||
def check_tp_pipeline_inference(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_tp_pipeline_inference_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_pipeline_inference():
|
||||
spawn(check_pipeline_inference, nprocs=2)
|
||||
spawn(check_tp_pipeline_inference, nprocs=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue