mirror of https://github.com/hpcaitech/ColossalAI
[Refactor] remove useless inference code (#5022)
* remove useless code
* fix quant model
* fix test import bug
* mv original inference legacy
* fix chatglm2

pull/5035/head
parent
81b8f5e76a
commit
c6295c3381
@ -0,0 +1,248 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..kvcache_manager import BatchInferState, MemoryManager
|
||||
|
||||
__all__ = ["MicroBatchManager"]
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
PREFILL = 1
|
||||
GENERATE = 2
|
||||
DONE = 3
|
||||
COOLDOWN = 4
|
||||
|
||||
|
||||
class MicroBatchDescription:
|
||||
"""
|
||||
This is the class to record the information of each microbatch, and also to perform some update operations.
This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
details, please refer to the docs of these two classes below.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
max_input_len (int): the maximum length of the input sequence.
max_output_len (int): the maximum number of tokens to generate.
cache_manager (MemoryManager): the KV-cache manager assigned to this microbatch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
) -> None:
|
||||
self.mb_length = inputs_dict["input_ids"].shape[-1]
|
||||
self.target_length = self.mb_length + max_output_len
|
||||
self.infer_state = BatchInferState.init_from_batch(
|
||||
batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
|
||||
)
|
||||
# print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""
|
||||
Return the state of the current micro batch: DONE when the current length equals the target length,
COOLDOWN when it is one token short of the target length, and GENERATE otherwise.
|
||||
|
||||
"""
|
||||
# TODO: add the condition for early stopping
|
||||
if self.cur_length == self.target_length:
|
||||
return Status.DONE
|
||||
elif self.cur_length == self.target_length - 1:
|
||||
return Status.COOLDOWN
|
||||
else:
|
||||
return Status.GENERATE
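# Illustrative walk-through (numbers assumed, not from the original file): with
# mb_length == 16 and max_output_len == 4, target_length == 20, so the state is
# GENERATE while cur_length is 16-18, COOLDOWN at 19, and DONE once it reaches 20.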
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
Return the current sequence length of the micro batch.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class HeadMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the information of the first pipeline stage. The first stage should have the attributes `input_ids`, `attention_mask`
and `new_tokens`, where `new_tokens` holds the tokens generated by the first stage. Due to the pipeline schedule, the way this
information is updated and the condition used to determine the state differ from the other stages.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
max_input_len (int): the maximum length of the input sequence.
max_output_len (int): the maximum number of tokens to generate.
cache_manager (MemoryManager): the KV-cache manager assigned to this microbatch.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
|
||||
assert inputs_dict is not None
|
||||
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
|
||||
self.input_ids = inputs_dict["input_ids"]
|
||||
self.attn_mask = inputs_dict["attention_mask"]
|
||||
self.new_tokens = None
|
||||
|
||||
def update(self, new_token: torch.Tensor = None):
|
||||
if new_token is not None:
|
||||
self._update_newtokens(new_token)
|
||||
if self.state is not Status.DONE and new_token is not None:
|
||||
self._update_attnmask()
|
||||
|
||||
def _update_newtokens(self, new_token: torch.Tensor):
|
||||
if self.new_tokens is None:
|
||||
self.new_tokens = new_token
|
||||
else:
|
||||
self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
|
||||
|
||||
def _update_attnmask(self):
|
||||
self.attn_mask = torch.cat(
|
||||
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
|
||||
)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
When there are no new tokens, the length is `mb_length`; otherwise, it is `mb_length` plus the number of generated tokens in `new_tokens`.
|
||||
|
||||
"""
|
||||
if self.new_tokens is None:
|
||||
return self.mb_length
|
||||
else:
|
||||
return self.mb_length + len(self.new_tokens[0])
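# Illustrative example (assumed numbers): a 16-token prompt with 3 generated tokens
# gives mb_length == 16, new_tokens of shape [batch_size, 3], and cur_length == 19.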
|
||||
|
||||
|
||||
class BodyMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the information of the pipeline stages other than the first one. These stages should have the attributes `hidden_states` and `past_key_values`.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. The non-first stages only receive hidden states from the previous stage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
Return the current sequence length of this micro batch, taken as the maximum `seq_len` recorded in the inference state (`infer_state.seq_len.max()`).
|
||||
|
||||
"""
|
||||
return self.infer_state.seq_len.max().item()
|
||||
|
||||
|
||||
class MicroBatchManager:
|
||||
"""
|
||||
MicroBatchManager is a class that manages the micro-batch descriptions and their KV caches across the pipeline schedule.
|
||||
|
||||
Args:
|
||||
stage (int): stage id of current stage.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
max_input_len (int): the maximum length of the input sequence.
max_output_len (int): the maximum number of tokens to generate.
cache_manager_list (List[MemoryManager]): one KV-cache manager per buffered micro batch.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stage: int,
|
||||
micro_batch_size: int,
|
||||
micro_batch_buffer_size: int,
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager_list: List[MemoryManager],
|
||||
):
|
||||
self.stage = stage
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.buffer_size = micro_batch_buffer_size
|
||||
self.max_input_len = max_input_len
|
||||
self.max_output_len = max_output_len
|
||||
self.cache_manager_list = cache_manager_list
|
||||
self.mb_descrption_buffer = {}
|
||||
self.new_tokens_buffer = {}
|
||||
self.idx = 0
|
||||
|
||||
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
|
||||
def step(self, new_token: torch.Tensor = None):
|
||||
"""
|
||||
Update the state of the microbatch manager. There are two cases:
1. For the first stage in PREFILL, the inputs and outputs are received, and `add_descrption` saves the inputs.
2. For any other case, only the output of the previous stage is received, and the description is updated.

Args:
new_token (torch.Tensor): the new token generated by the current stage.
|
||||
"""
|
||||
# Update the current description (added earlier via `add_descrption`) with the new token
|
||||
self.cur_descrption.update(new_token)
|
||||
return self.cur_state
|
||||
|
||||
def export_new_tokens(self):
|
||||
new_tokens_list = []
|
||||
for i in self.mb_descrption_buffer.values():
|
||||
new_tokens_list.extend(i.new_tokens.tolist())
|
||||
return new_tokens_list
|
||||
|
||||
def is_micro_batch_done(self):
|
||||
if len(self.mb_descrption_buffer) == 0:
|
||||
return False
|
||||
for mb in self.mb_descrption_buffer.values():
|
||||
if mb.state != Status.DONE:
|
||||
return False
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
self.mb_descrption_buffer.clear()
|
||||
for cache in self.cache_manager_list:
|
||||
cache.free_all()
|
||||
|
||||
def next(self):
|
||||
self.idx = (self.idx + 1) % self.buffer_size
|
||||
|
||||
def _remove_descrption(self):
|
||||
self.mb_descrption_buffer.pop(self.idx)
|
||||
|
||||
@property
|
||||
def cur_descrption(self) -> MicroBatchDescription:
|
||||
return self.mb_descrption_buffer.get(self.idx)
|
||||
|
||||
@property
|
||||
def cur_infer_state(self):
|
||||
if self.cur_descrption is None:
|
||||
return None
|
||||
return self.cur_descrption.infer_state
|
||||
|
||||
@property
|
||||
def cur_state(self):
|
||||
"""
|
||||
Return the state of the current micro batch. When the current description is None, the state is PREFILL.
|
||||
|
||||
"""
|
||||
if self.cur_descrption is None:
|
||||
return Status.PREFILL
|
||||
return self.cur_descrption.state
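# --------------------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file). The helper below illustrates,
# under assumed names, how a first pipeline stage could drive MicroBatchManager through
# its PREFILL -> GENERATE -> COOLDOWN -> DONE lifecycle. `micro_batches`, `cache_managers`
# and `next_token_fn` are hypothetical arguments; the real driver is the generate schedule
# in colossalai.pipeline.schedule.generate.
def _example_microbatch_loop(micro_batches, cache_managers, next_token_fn, max_input_len=32, max_output_len=32):
    manager = MicroBatchManager(
        stage=0,
        micro_batch_size=1,
        micro_batch_buffer_size=len(cache_managers),
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        cache_manager_list=cache_managers,
    )
    # Prefill: register one description per buffered micro batch.
    for inputs_dict in micro_batches:
        manager.add_descrption(inputs_dict)
        manager.next()
    # Decode: step every buffered micro batch until all of them report DONE.
    while not manager.is_micro_batch_done():
        new_token = next_token_fn(manager.cur_infer_state)  # hypothetical token sampler
        manager.step(new_token)
        manager.next()
    tokens = manager.export_new_tokens()
    manager.clear()
    return tokens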
|
@ -1,4 +1,5 @@
from .bloom import BloomInferenceForwards
from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards

__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"]
__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards", "ChatGLM2InferenceForwards"]
@ -1,5 +1,5 @@
from .bloom import BloomModelInferPolicy
from .chatglm import ChatGLM2InferPolicy
from .chatglm2 import ChatGLM2InferPolicy
from .llama import LlamaModelInferPolicy

__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
@ -0,0 +1,2 @@
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
@ -0,0 +1,4 @@
from .hybridengine import CaiInferEngine
from .hybridengine.polices import LlamaModelInferPolicy

__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
@ -0,0 +1,3 @@
from .engine import CaiInferEngine

__all__ = ["CaiInferEngine"]
@ -0,0 +1,170 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.schedule.generate import GenerateSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from ..pipeline.microbatch_manager import MicroBatchManager
|
||||
from ..tensor_parallel.kvcache_manager import MemoryManager
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = [
|
||||
"LlamaForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
class CaiInferEngine:
|
||||
"""
|
||||
CaiInferEngine is a class that handles the pipeline parallel inference.
|
||||
|
||||
Args:
|
||||
tp_size (int): the size of tensor parallelism.
|
||||
pp_size (int): the size of pipeline parallelism.
|
||||
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
|
||||
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
max_batch_size (int): the maximum batch size.
|
||||
max_input_len (int): the maximum input length.
|
||||
max_output_len (int): the maximum output length.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
|
||||
import colossalai
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")
|
||||
# assume the model runs inference with 2 pipeline stages
|
||||
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
|
||||
|
||||
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
|
||||
data = tokenizer(input, return_tensors='pt')
|
||||
output = inferengine.inference([data.to('cuda').data])
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tp_size: int = 1,
|
||||
pp_size: int = 1,
|
||||
dtype: str = "fp16",
|
||||
model: nn.Module = None,
|
||||
model_policy: Policy = None,
|
||||
micro_batch_size: int = 1,
|
||||
micro_batch_buffer_size: int = None,
|
||||
max_batch_size: int = 4,
|
||||
max_input_len: int = 32,
|
||||
max_output_len: int = 32,
|
||||
verbose: bool = False,
|
||||
# TODO: implement early_stopping, and various generation options
|
||||
early_stopping: bool = False,
|
||||
do_sample: bool = False,
|
||||
num_beams: int = 1,
|
||||
) -> None:
|
||||
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
|
||||
assert (
|
||||
tp_size * pp_size == dist.get_world_size()
|
||||
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
|
||||
assert model and model_policy, "Both model and model_policy should be provided."
|
||||
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
|
||||
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
|
||||
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
|
||||
|
||||
# TODO: support tensor-parallel-only inference (pp_size == 1)
assert pp_size > 1, "Tensor-parallel-only inference (pp_size == 1) is not supported yet."
|
||||
self.pp_size = pp_size
|
||||
self.tp_size = tp_size
|
||||
|
||||
if dtype == "fp16":
|
||||
self.dtype = torch.float16
|
||||
model.half()
|
||||
elif dtype == "bf16":
|
||||
self.dtype = torch.bfloat16
|
||||
model.to(torch.bfloat16)
|
||||
else:
|
||||
self.dtype = torch.float32
|
||||
|
||||
# Init pg mesh
|
||||
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
|
||||
|
||||
stage_manager = None
|
||||
if pp_size > 1:
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
]
|
||||
self.mb_manager = MicroBatchManager(
|
||||
stage_manager.stage,
|
||||
micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size,
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
self.cache_manager_list,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
|
||||
|
||||
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
|
||||
|
||||
def inference(self, input_list):
|
||||
"""
|
||||
Args:
|
||||
input_list (Union[BatchEncoding, dict]): the input batch, a `BatchEncoding` or `dict` containing `input_ids` and `attention_mask`.
|
||||
|
||||
Returns:
|
||||
out (list): a list of output data, where each element is a list of generated token ids.
timestamp (float): the time cost of the inference, only returned when `verbose` is `True`.
|
||||
"""
|
||||
assert isinstance(
|
||||
input_list, (BatchEncoding, dict)
|
||||
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
|
||||
if isinstance(input_list, BatchEncoding):
|
||||
input_list = input_list.data
|
||||
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
|
||||
if self.verbose:
|
||||
return out, timestamp
|
||||
else:
|
||||
return out
|
||||
|
||||
def _shardformer(self, model, model_policy, stage_manager, tp_group):
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=False,
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
||||
|
||||
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
|
||||
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
|
||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
head_num = model.config.num_attention_heads
|
||||
num_hidden_layers = (
|
||||
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
|
||||
)
|
||||
layer_num = num_hidden_layers // self.pp_size
|
||||
|
||||
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
|
||||
return cache_manager
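# Hedged sizing example (illustrative numbers, not from the original file): with the
# defaults max_batch_size=4, max_input_len=32 and max_output_len=32, each MemoryManager
# reserves max_total_token_num = 4 * (32 + 32) = 256 KV-cache slots per held layer,
# and one such manager is created for every buffered micro batch
# (micro_batch_buffer_size, or pp_size when the buffer size is not given).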
|
@ -0,0 +1,3 @@
from .llama import LlamaInferenceForwards

__all__ = ["LlamaInferenceForwards"]
@ -0,0 +1,489 @@
|
||||
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
|
||||
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
try:
|
||||
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_llama2_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
|
||||
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
except:
|
||||
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
|
||||
HAS_FLASH_KERNEL = True
|
||||
except:
|
||||
HAS_FLASH_KERNEL = False
|
||||
print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
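# Shape note (descriptive, assuming the standard HF LLaMA layout): q and k arrive as
# [bs, num_heads, seq_len, head_dim] and cos/sin as [1, 1, seq_len, head_dim]; after the
# squeeze and position_ids indexing above, cos/sin broadcast as [bs, 1, seq_len, head_dim],
# so q_embed and k_embed keep the original [bs, num_heads, seq_len, head_dim] layout.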
|
||||
|
||||
|
||||
def llama_triton_context_attention(
|
||||
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
|
||||
):
|
||||
if num_key_value_groups == 1:
|
||||
if HAS_LIGHTLLM_KERNEL is False:
|
||||
llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
lightllm_context_attention_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
|
||||
lightllm_llama2_context_attention_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
|
||||
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
|
||||
if num_key_value_groups == 1:
|
||||
token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
Llama2TokenAttentionForwards.token_attn(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
infer_state.other_kv_index,
|
||||
)
|
||||
|
||||
|
||||
class LlamaInferenceForwards:
|
||||
"""
|
||||
This class holds forwards for llama inference.
|
||||
We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention in LlamaForCausalLM.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def llama_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
infer_state: BatchInferState = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# If this is the first stage and hidden_states are already available (i.e. after warmup), go through lm_head first
|
||||
if stage_manager.is_first_stage() and hidden_states is not None:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
return {"logits": lm_logits}
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = LlamaInferenceForwards.llama_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
infer_state=infer_state,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def llama_model_forward(
|
||||
self: LlamaModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
infer_state: BatchInferState = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
assert stage_manager is not None
|
||||
assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}"
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
past_key_values_length = 0
|
||||
else:
|
||||
past_key_values_length = infer_state.max_len_in_batch - 1
|
||||
|
||||
# NOTE: differentiate between the prefill and decode stages;
# block_loc requires a different value-assigning method for each stage
|
||||
if use_cache and seq_length != 1:
|
||||
# NOTE assume prefill stage
|
||||
# allocate memory block
|
||||
infer_state.is_context_stage = True # set prefill stage, notify attention layer
|
||||
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
|
||||
infer_state.init_block_loc(
|
||||
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
|
||||
)
|
||||
else:
|
||||
infer_state.is_context_stage = False
|
||||
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
|
||||
if alloc_mem is not None:
|
||||
infer_state.decode_is_contiguous = True
|
||||
infer_state.decode_mem_index = alloc_mem[0]
|
||||
infer_state.decode_mem_start = alloc_mem[1]
|
||||
infer_state.decode_mem_end = alloc_mem[2]
|
||||
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
|
||||
else:
|
||||
infer_state.decode_is_contiguous = False
|
||||
alloc_mem = infer_state.cache_manager.alloc(batch_size)
|
||||
infer_state.decode_mem_index = alloc_mem
|
||||
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.repeat(batch_size, 1)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
|
||||
else:
|
||||
seq_len = infer_state.seq_len
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
infer_state.decode_layer_id = 0
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * (end_idx - start_idx + 1))
|
||||
|
||||
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
|
||||
decoder_layer = self.layers[idx]
|
||||
# NOTE: modify here for passing args to decoder layer
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
infer_state.decode_layer_id += 1
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# update indices
|
||||
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.seq_len += 1
|
||||
infer_state.max_len_in_batch += 1
|
||||
|
||||
# if not return_dict:
|
||||
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
# return BaseModelOutputWithPast(
|
||||
# last_hidden_state=hidden_states,
|
||||
# past_key_values=next_cache,
|
||||
# hidden_states=all_hidden_states,
|
||||
# attentions=all_self_attns,
|
||||
# )
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def llama_decoder_layer_forward(
|
||||
self: LlamaDecoderLayer,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def llama_flash_attn_kvcache_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
assert use_cache is True, "use_cache should be set to True using this llama attention"
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# NOTE might think about better way to handle transposed k and v
|
||||
# key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
|
||||
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# NOTE might want to revise
|
||||
# need some way to record the length of past key values cache
|
||||
# since we won't return past_key_value_cache right now
|
||||
|
||||
cos, sin = infer_state.position_cos, infer_state.position_sin
|
||||
|
||||
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
|
||||
|
||||
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
# first token generation
|
||||
# copy key and value calculated in current step to memory manager
|
||||
copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_states,
|
||||
value_states,
|
||||
infer_state.context_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
llama_triton_context_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state,
|
||||
num_key_value_groups=self.num_key_value_groups,
|
||||
)
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_k.copy_(key_states)
|
||||
cache_v.copy_(value_states)
|
||||
else:
|
||||
# if decode is not contiguous, use triton kernel to copy key and value cache
|
||||
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
|
||||
copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_states,
|
||||
value_states,
|
||||
infer_state.decode_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
if HAS_LIGHTLLM_KERNEL:
|
||||
attn_output = torch.empty_like(query_states)
|
||||
llama_triton_token_attention(
|
||||
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
|
||||
)
|
||||
else:
|
||||
# grouped KV heads (num_heads // num_key_value_heads) are handled by flash_attn_with_kvcache below
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
|
||||
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
|
||||
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
||||
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
attn_output = flash_attn_with_kvcache(
|
||||
q=query_states,
|
||||
k_cache=copy_cache_k,
|
||||
v_cache=copy_cache_v,
|
||||
softmax_scale=1 / math.sqrt(self.head_dim),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
# return past_key_value as None
|
||||
return attn_output, None, None
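# Summary (descriptive note, added): the context (prefill) branch above copies the freshly
# computed K/V into the memory-manager cache and runs a context-attention kernel over the
# whole prompt, while the decode branch appends a single K/V slot per sequence (contiguously
# when possible) and attends against the cached keys/values, using either the lightllm
# token-attention kernels or flash_attn_with_kvcache as a fallback.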
|
@ -0,0 +1,3 @@
from .llama import LlamaModelInferPolicy

__all__ = ["LlamaModelInferPolicy"]
@ -0,0 +1,142 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
|
||||
# import colossalai
|
||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
from ..modeling._utils import init_to_get_rotary
|
||||
from ..modeling.llama import LlamaInferenceForwards
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import rmsnorm_forward
|
||||
|
||||
HAS_TRITON_RMSNORM = True
|
||||
except:
|
||||
print("you should install triton from https://github.com/openai/triton")
|
||||
HAS_TRITON_RMSNORM = False
|
||||
|
||||
|
||||
def get_triton_rmsnorm_forward():
|
||||
if HAS_TRITON_RMSNORM:
|
||||
|
||||
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
||||
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
||||
|
||||
return _triton_rmsnorm_forward
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.inference_gptq:
|
||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
}
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=RowCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=RowCaiQuantLinear,
|
||||
kwargs={"split_num": 1},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self.shard_config._infer()
|
||||
|
||||
infer_forward = LlamaInferenceForwards.llama_model_forward
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
|
||||
|
||||
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
||||
)
|
||||
|
||||
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=LlamaAttention
|
||||
)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
|
||||
)
|
||||
infer_forward = None
|
||||
if HAS_TRITON_RMSNORM:
|
||||
infer_forward = get_triton_rmsnorm_forward()
|
||||
|
||||
if infer_forward is not None:
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
init_to_get_rotary(self.model.model)
|
||||
return self.model
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
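# Hedged usage sketch (assumed entry points, mirroring the engine docstring):
#
#   from transformers import LlamaForCausalLM
#   from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
#
#   model = LlamaForCausalLM.from_pretrained("your_path_to_model")
#   engine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
#   output = engine.inference(tokenizer_output.data)  # tokenizer_output is a BatchEncoding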
|
@ -0,0 +1,4 @@
from .cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
@ -0,0 +1,14 @@
|
||||
import warnings
|
||||
|
||||
HAS_AUTO_GPTQ = False
|
||||
try:
|
||||
import auto_gptq
|
||||
|
||||
HAS_AUTO_GPTQ = True
|
||||
except ImportError:
|
||||
warnings.warn("please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ")
|
||||
HAS_AUTO_GPTQ = False
|
||||
|
||||
if HAS_AUTO_GPTQ:
|
||||
from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear
|
||||
from .gptq_op import CaiGPTQLinearOp
|
@ -0,0 +1,354 @@
|
||||
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import ParallelModule
|
||||
|
||||
from .gptq_op import CaiGPTQLinearOp
|
||||
|
||||
HAS_GPTQ_CUDA = False
|
||||
try:
|
||||
from colossalai.kernel.op_builder.gptq import GPTQBuilder
|
||||
|
||||
gptq_cuda = GPTQBuilder().load()
|
||||
HAS_GPTQ_CUDA = True
|
||||
except ImportError:
|
||||
warnings.warn("CUDA gptq is not installed")
|
||||
HAS_GPTQ_CUDA = False
|
||||
|
||||
|
||||
class CaiQuantLinear(nn.Module):
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
|
||||
self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
||||
self.register_buffer(
|
||||
"qzeros",
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
|
||||
)
|
||||
self.register_buffer(
|
||||
"scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
if row_split:
|
||||
self.register_buffer(
|
||||
"g_idx",
|
||||
torch.tensor(
|
||||
[(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
|
||||
|
||||
self.q4 = None
|
||||
self.empty_tensor = torch.empty((1, 1), device="meta")
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.row_split = row_split
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
g_idx = (
|
||||
g_idx.clone()
|
||||
if g_idx is not None
|
||||
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
half_scales = scales.clone().half()
|
||||
# print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
pbits = 32
|
||||
ptype = torch.int32
|
||||
unsign_type = np.uint32
|
||||
sign_type = np.int32
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
|
||||
:, None
|
||||
]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(unsign_type)
|
||||
qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)
|
||||
|
||||
i = 0
|
||||
row = 0
|
||||
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (pbits // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += pbits // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
qweight = qweight.astype(sign_type)
|
||||
qweight1 = torch.from_numpy(qweight)
|
||||
qweight1 = qweight1.contiguous() # .to("cuda")
|
||||
self.qweight.data.copy_(qweight1)
|
||||
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(unsign_type)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (pbits // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += pbits // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
qzeros = qzeros.astype(sign_type)
|
||||
qzeros = torch.from_numpy(qzeros)
|
||||
qzeros = qzeros
|
||||
self.qzeros.data.copy_(qzeros)
|
||||
|
||||
if torch.equal(self.g_idx.to(g_idx.device), g_idx):
|
||||
self.g_idx = None
|
||||
else:
|
||||
self.g_idx = g_idx
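# Worked packing example (illustrative values): with bits=4 and pbits=32, each int32 entry
# of qweight stores 32 // 4 = 8 consecutive 4-bit weights of one column. For the column
# values [3, 1, 0, 7, 2, 5, 4, 6] the loop above computes
#   3 | 1<<4 | 0<<8 | 7<<12 | 2<<16 | 5<<20 | 4<<24 | 6<<28 == 0x64527013,
# i.e. eight quantized weights per 32-bit word; qzeros is packed the same way along its columns.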
|
||||
|
||||
def init_q4(self):
|
||||
assert self.qweight.device.type == "cuda"
|
||||
self.q4_width = self.qweight.shape[1]
|
||||
if self.g_idx is not None:
|
||||
if self.row_split and torch.equal(
|
||||
self.g_idx,
|
||||
torch.tensor(
|
||||
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
|
||||
dtype=torch.int32,
|
||||
device=self.g_idx.device,
|
||||
),
|
||||
):
|
||||
self.g_idx = None
|
||||
elif torch.equal(
|
||||
self.g_idx,
|
||||
torch.tensor(
|
||||
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
|
||||
),
|
||||
):
|
||||
self.g_idx = None
|
||||
|
||||
if self.g_idx is not None:
|
||||
g_idx = self.g_idx.to("cpu")
|
||||
else:
|
||||
g_idx = self.empty_tensor
|
||||
|
||||
self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def forward(self, x):
|
||||
outshape = x.shape[:-1] + (self.outfeatures,)
|
||||
|
||||
if HAS_GPTQ_CUDA and self.bits == 4:
|
||||
if self.q4 is None:
|
||||
self.init_q4()
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
|
||||
gptq_cuda.q4_matmul(x.half(), self.q4, output)
|
||||
if self.bias is not None and (not self.row_split or self.tp_size == 1):
|
||||
output.add_(self.bias)
|
||||
else:
|
||||
if self.bias is not None and (not self.row_split or self.tp_size == 1):
|
||||
bias = self.bias
|
||||
else:
|
||||
bias = None
|
||||
output = self.gptq_linear(
|
||||
x,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
g_idx=self.g_idx,
|
||||
bias=bias,
|
||||
)
|
||||
return output.view(outshape)
|
||||
|
||||
|
||||
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
|
||||
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
|
||||
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
g_idx = gptq_linear.g_idx
|
||||
if gptq_linear.bias is not None:
|
||||
bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
|
||||
cai_split_out_features = cai_linear.outfeatures // split_num
|
||||
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
|
||||
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
|
||||
:, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
|
||||
]
|
||||
cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
|
||||
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
if cai_linear.bias is not None:
|
||||
cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
|
||||
tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
|
||||
cai_linear.g_idx.copy_(g_idx)
|
||||
|
||||
|
||||
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
|
||||
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
|
||||
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
|
||||
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
|
||||
g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)
|
||||
|
||||
cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
|
||||
zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
|
||||
idx_split_features = cai_linear.infeatures // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
|
||||
tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
|
||||
]
|
||||
cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
|
||||
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
|
||||
]
|
||||
cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
|
||||
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
|
||||
]
|
||||
cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
|
||||
tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
|
||||
]
|
||||
if cai_linear.bias is not None:
|
||||
cai_linear.bias.copy_(gptq_linear.bias)
|
||||
|
||||
|
||||
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
super().__init__(
|
||||
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
tp_rank = dist.get_rank(process_group)
|
||||
|
||||
if in_features < tp_size:
|
||||
return module
|
||||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
linear_1d = RowCaiQuantLinear(
|
||||
module.bits,
|
||||
module.group_size,
|
||||
module.in_features // tp_size,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=True,
|
||||
)
|
||||
linear_1d.process_group = process_group
|
||||
|
||||
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
|
||||
return linear_1d
|
||||
|
||||
def forward(self, x):
|
||||
output = super().forward(x)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
||||
if self.bias is not None:
|
||||
output.add_(self.bias)
|
||||
return output
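# Design note (reasoning added, not in the original): in the row-parallel case each rank
# multiplies against only a slice of the input features, so the partial results must be
# summed across the tensor-parallel group first; the (unsplit) bias is then added exactly
# once, after the all-reduce, instead of being accumulated tp_size times.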
|
||||
|
||||
|
||||
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(
            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
        )
        self.process_group = None

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features: {in_features} is not an integer multiple of the tensor parallel size: {tp_size}!"
            )
        linear_1d = ColCaiQuantLinear(
            module.bits,
            module.group_size,
            module.in_features,
            module.out_features // tp_size,
            module.bias is not None,
            tp_size=tp_size,
            tp_rank=tp_rank,
        )
        linear_1d.process_group = process_group

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d
|
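
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# It assumes torch.distributed is already initialized and that `gptq_linear` is an
# AutoGPTQ-style quantized linear exposing `bits`, `group_size`, `qweight`, `qzeros`,
# `scales`, `g_idx` and (optionally) `bias`. The helper name is hypothetical.
def _example_shard_gptq_linear(gptq_linear, process_group):
    # Column sharding keeps in_features intact and splits out_features across ranks.
    col_linear = ColCaiQuantLinear.from_native_module(gptq_linear, process_group)
    # Row sharding splits in_features and all-reduces the partial outputs in forward().
    row_linear = RowCaiQuantLinear.from_native_module(gptq_linear, process_group)
    return col_linear, row_linear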
@ -0,0 +1,58 @@
|
||||
import torch
|
||||
|
||||
from colossalai.kernel.triton import gptq_fused_linear_triton
|
||||
|
||||
|
||||
class CaiGPTQLinearOp(torch.nn.Module):
|
||||
def __init__(self, gptq_group_size, gptq_quant_bits):
|
||||
super(CaiGPTQLinearOp, self).__init__()
|
||||
self.group_size = gptq_group_size
|
||||
self.bits = gptq_quant_bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scales: torch.Tensor,
|
||||
weight_zeros: torch.Tensor,
|
||||
g_idx: torch.Tensor = None,
|
||||
act_type=0,
|
||||
bias: torch.Tensor = None,
|
||||
residual: torch.Tensor = None,
|
||||
qkv_fused=False,
|
||||
):
|
||||
add_bias = True
|
||||
if bias is None:
|
||||
bias = self.empty_tensor
|
||||
add_bias = False
|
||||
|
||||
add_residual = True
|
||||
if residual is None:
|
||||
residual = self.empty_tensor
|
||||
add_residual = False
|
||||
x = input.view(-1, input.shape[-1])
|
||||
|
||||
out = gptq_fused_linear_triton(
|
||||
x,
|
||||
weight,
|
||||
weight_scales,
|
||||
weight_zeros,
|
||||
bias,
|
||||
residual,
|
||||
self.bits,
|
||||
self.maxq,
|
||||
self.group_size,
|
||||
qkv_fused,
|
||||
add_bias,
|
||||
add_residual,
|
||||
act_type=act_type,
|
||||
g_idx=g_idx,
|
||||
)
|
||||
if qkv_fused:
|
||||
out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
|
||||
else:
|
||||
out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
|
||||
|
||||
return out
|
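
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# `packed` is a hypothetical container holding the GPTQ buffers (qweight/scales/qzeros/g_idx).
# The op flattens the input to 2D, runs the fused dequantize+GEMM Triton kernel, and reshapes
# the result back to (batch, seq, out_features), or (3, batch, seq, out_features) when the
# fused QKV path is used.
def _example_gptq_linear_call(op, hidden_states, packed):
    return op(
        hidden_states,       # (batch, seq, in_features) activations
        packed.qweight,      # packed int32 quantized weights
        packed.scales,       # per-group dequantization scales
        packed.qzeros,       # packed zero points
        g_idx=packed.g_idx,  # group index used for act-order quantization
    )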
@ -0,0 +1,12 @@
try:
    import torch_int

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    raise ImportError(
        "torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
    )

if HAS_TORCH_INT:
    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
@ -0,0 +1,487 @@
|
||||
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
|
||||
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
|
||||
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from os.path import isdir, isfile, join
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import accelerate
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from safetensors.torch import save_file as safe_save
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
from transformers.utils.generic import ContextManagers
|
||||
from transformers.utils.hub import PushToHubMixin, cached_file
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
|
||||
|
||||
SUPPORTED_MODELS = ["llama"]
|
||||
|
||||
|
||||
class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
|
||||
layer_type: str = None
|
||||
|
||||
def __init__(self, model: PreTrainedModel, quantized: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.model_type = self.model.config.model_type
|
||||
self._quantized = quantized
|
||||
self.config = self.model.config
|
||||
self.cache_manager = None
|
||||
self.max_total_token_num = 0
|
||||
|
||||
@property
|
||||
def quantized(self):
|
||||
return self._quantized
|
||||
|
||||
def init_cache_manager(self, max_total_token_num=2048):
|
||||
if self.config.model_type == "llama":
|
||||
head_num = self.config.num_key_value_heads
|
||||
layer_num = self.config.num_hidden_layers
|
||||
head_dim = self.config.hidden_size // head_num
|
||||
|
||||
self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
|
||||
self.max_total_token_num = max_total_token_num
|
||||
|
||||
def init_batch_state(self, max_output_len=256, **kwargs):
|
||||
input_ids = kwargs["input_ids"]
|
||||
batch_size = len(input_ids)
|
||||
|
||||
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||
start_index = 0
|
||||
max_len_in_batch = -1
|
||||
|
||||
for i in range(batch_size):
|
||||
seq_len = len(input_ids[i])
|
||||
seq_lengths[i] = seq_len
|
||||
seq_start_indexes[i] = start_index
|
||||
start_index += seq_len
|
||||
max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch
|
||||
|
||||
if "max_total_token_num" in kwargs.keys():
|
||||
max_total_token_num = kwargs["max_total_token_num"]
|
||||
self.init_cache_manager(max_total_token_num)
|
||||
|
||||
if "max_new_tokens" in kwargs.keys():
|
||||
max_output_len = kwargs["max_new_tokens"]
|
||||
|
||||
if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
|
||||
max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
|
||||
warnings.warn(f"reset max tokens to {max_total_token_num}")
|
||||
self.init_cache_manager(max_total_token_num)
|
||||
|
||||
block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
|
||||
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
|
||||
batch_infer_state.seq_len = seq_lengths.to("cuda")
|
||||
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
|
||||
batch_infer_state.block_loc = block_loc
|
||||
batch_infer_state.decode_layer_id = 0
|
||||
batch_infer_state.is_context_stage = True
|
||||
batch_infer_state.set_cache_manager(self.cache_manager)
|
||||
batch_infer_state.cache_manager.free_all()
|
||||
return batch_infer_state
|
||||
|
||||
@abstractmethod
|
||||
@torch.inference_mode()
|
||||
def quantize(
|
||||
self,
|
||||
examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
|
||||
):
|
||||
if self.quantized:
|
||||
raise EnvironmentError("can't execute quantize because the model is quantized.")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""shortcut for model.generate"""
|
||||
|
||||
batch_infer_state = self.init_batch_state(**kwargs)
|
||||
if self.config.model_type == "llama":
|
||||
setattr(self.model.model, "infer_state", batch_infer_state)
|
||||
|
||||
with torch.inference_mode():
|
||||
return self.model.generate(**kwargs)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
"""shortcut for model.prepare_inputs_for_generation"""
|
||||
return self.model.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
|
||||
for text in tqdm(dataset):
|
||||
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
|
||||
model(input_ids)
|
||||
|
||||
def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
|
||||
pbar = tqdm(dataset)
|
||||
for text in pbar:
|
||||
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
|
||||
model(input_ids)
|
||||
mean_scale = np.mean([v["input"] for v in act_dict.values()])
|
||||
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
|
||||
|
||||
# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
|
||||
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
|
||||
model.eval()
|
||||
device = next(model.parameters()).device
|
||||
act_scales = {}
|
||||
|
||||
def stat_tensor(name, tensor):
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs().detach()
|
||||
comming_max = torch.max(tensor, dim=0)[0].float().cpu()
|
||||
if name in act_scales:
|
||||
act_scales[name] = torch.max(act_scales[name], comming_max)
|
||||
else:
|
||||
act_scales[name] = comming_max
|
||||
|
||||
def stat_input_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
stat_tensor(name, x)
|
||||
|
||||
hooks = []
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
|
||||
|
||||
self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
return act_scales
|
||||
|
||||
# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
|
||||
@torch.no_grad()
|
||||
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
|
||||
if not isinstance(fcs, list):
|
||||
fcs = [fcs]
|
||||
for fc in fcs:
|
||||
assert isinstance(fc, nn.Linear)
|
||||
assert ln.weight.numel() == fc.in_features == act_scales.numel()
|
||||
|
||||
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
|
||||
act_scales = act_scales.to(device=device, dtype=dtype)
|
||||
weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
|
||||
|
||||
scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
|
||||
|
||||
ln.weight.div_(scales)
|
||||
if hasattr(ln, "bias"):
|
||||
ln.bias.div_(scales)
|
||||
|
||||
for fc in fcs:
|
||||
fc.weight.mul_(scales.view(1, -1))
|
||||
|
||||
@classmethod
def create_quantized_model(cls, model):
    raise NotImplementedError("The create_quantized_model method is not implemented")
|
||||
|
||||
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
|
||||
def save_quantized(
|
||||
self,
|
||||
save_dir: str,
|
||||
model_basename: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""save quantized model and configs to local disk"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not self.quantized:
|
||||
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
|
||||
|
||||
self.model.to("cpu")
|
||||
|
||||
model_base_name = model_basename # or f"smooth-"
|
||||
if use_safetensors:
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
state_dict = self.model.state_dict()
|
||||
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
||||
if safetensors_metadata is None:
|
||||
safetensors_metadata = {}
|
||||
elif not isinstance(safetensors_metadata, dict):
|
||||
raise TypeError("safetensors_metadata must be a dictionary.")
|
||||
else:
|
||||
print(f"Received safetensors_metadata: {safetensors_metadata}")
|
||||
new_safetensors_metadata = {}
|
||||
converted_keys = False
|
||||
for key, value in safetensors_metadata.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
converted_keys = True
|
||||
try:
|
||||
new_key = str(key)
|
||||
new_value = str(value)
|
||||
except Exception as e:
|
||||
raise TypeError(
|
||||
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
|
||||
)
|
||||
if new_key in new_safetensors_metadata:
|
||||
print(
|
||||
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
|
||||
)
|
||||
new_safetensors_metadata[new_key] = new_value
|
||||
safetensors_metadata = new_safetensors_metadata
|
||||
if converted_keys:
|
||||
print(
|
||||
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
|
||||
)
|
||||
|
||||
# Format is required to enable Accelerate to load the metadata
|
||||
# otherwise it raises an OSError
|
||||
safetensors_metadata["format"] = "pt"
|
||||
|
||||
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
|
||||
else:
|
||||
model_save_name = model_base_name + ".bin"
|
||||
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
|
||||
|
||||
self.model.config.save_pretrained(save_dir)
|
||||
|
||||
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_dir: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""alias of save_quantized"""
|
||||
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
|
||||
|
||||
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
max_memory: Optional[dict] = None,
|
||||
trust_remote_code: bool = False,
|
||||
torch_dtype: torch.dtype = torch.float16,
|
||||
**model_init_kwargs,
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
|
||||
|
||||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = skip
|
||||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
||||
# Parameters related to loading from Hugging Face Hub
|
||||
cache_dir = model_init_kwargs.pop("cache_dir", None)
|
||||
force_download = model_init_kwargs.pop("force_download", False)
|
||||
resume_download = model_init_kwargs.pop("resume_download", False)
|
||||
proxies = model_init_kwargs.pop("proxies", None)
|
||||
local_files_only = model_init_kwargs.pop("local_files_only", False)
|
||||
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
|
||||
revision = model_init_kwargs.pop("revision", None)
|
||||
subfolder = model_init_kwargs.pop("subfolder", "")
|
||||
model_init_kwargs.pop("_commit_hash", None)
|
||||
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
}
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
||||
# enforce some values despite user specified
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
model_init_kwargs["trust_remote_code"] = trust_remote_code
|
||||
if max_memory:
|
||||
if "disk" in max_memory:
|
||||
raise NotImplementedError("disk offload not support yet.")
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
model.tie_weights()
|
||||
|
||||
max_memory = accelerate.utils.get_balanced_memory(
|
||||
model,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=[cls.layer_type],
|
||||
dtype=model_init_kwargs["torch_dtype"],
|
||||
low_zero=False,
|
||||
)
|
||||
model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
|
||||
model,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=[cls.layer_type],
|
||||
dtype=model_init_kwargs["torch_dtype"],
|
||||
)
|
||||
model_init_kwargs["low_cpu_mem_usage"] = True
|
||||
|
||||
del model
|
||||
else:
|
||||
model_init_kwargs["device_map"] = None
|
||||
model_init_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
|
||||
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
for key in seq_len_keys:
|
||||
if key in model_config:
|
||||
model.seqlen = model_config[key]
|
||||
break
|
||||
else:
|
||||
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
model.eval()
|
||||
|
||||
return cls(model, False)
|
||||
|
||||
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
|
||||
@classmethod
|
||||
def from_quantized(
|
||||
cls,
|
||||
model_name_or_path: Optional[str],
|
||||
model_basename: Optional[str] = None,
|
||||
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
|
||||
max_memory: Optional[dict] = None,
|
||||
device: Optional[Union[str, int]] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
use_safetensors: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""load quantized model from local disk"""
|
||||
|
||||
# Parameters related to loading from Hugging Face Hub
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"_raise_exceptions_for_missing_entries": False,
|
||||
"_commit_hash": commit_hash,
|
||||
}
|
||||
|
||||
# == step1: prepare configs and file names == #
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
|
||||
)
|
||||
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
||||
extensions = []
|
||||
if use_safetensors:
|
||||
extensions.append(".safetensors")
|
||||
else:
|
||||
extensions += [".bin", ".pt"]
|
||||
|
||||
model_name_or_path = str(model_name_or_path)
|
||||
is_local = isdir(model_name_or_path)
|
||||
|
||||
resolved_archive_file = None
|
||||
if is_local:
|
||||
model_save_name = join(model_name_or_path, model_basename)
|
||||
for ext in extensions:
|
||||
if isfile(model_save_name + ext):
|
||||
resolved_archive_file = model_save_name + ext
|
||||
break
|
||||
else: # remote
|
||||
for ext in extensions:
|
||||
resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
|
||||
if resolved_archive_file is not None:
|
||||
break
|
||||
|
||||
if resolved_archive_file is None: # Could not find a model file to use
|
||||
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
|
||||
|
||||
model_save_name = resolved_archive_file
|
||||
|
||||
# == step2: convert model to quantized-model (replace Linear) == #
|
||||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = skip
|
||||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
||||
transformers.modeling_utils._init_weights = False
|
||||
|
||||
init_contexts = [no_init_weights()]
|
||||
if low_cpu_mem_usage:
|
||||
init_contexts.append(accelerate.init_empty_weights(include_buffers=True))
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
|
||||
)
|
||||
cls.create_quantized_model(model)
|
||||
model.tie_weights()
|
||||
|
||||
# == step3: load checkpoint to quantized-model == #
|
||||
accelerate.utils.modeling.load_checkpoint_in_model(
|
||||
model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
|
||||
)
|
||||
|
||||
# == step4: set seqlen == #
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
for key in seq_len_keys:
|
||||
if key in model_config:
|
||||
model.seqlen = model_config[key]
|
||||
break
|
||||
else:
|
||||
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
|
||||
return cls(
|
||||
model,
|
||||
True,
|
||||
)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return super().__getattr__(item)
|
||||
except:
|
||||
return getattr(self.model, item)
|
||||
|
||||
|
||||
__all__ = ["BaseSmoothForCausalLM"]
|
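
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# Shows the save/reload path of a BaseSmoothForCausalLM subclass; the helper name and the
# 4096-token cache budget are assumptions for the example.
def _example_save_and_reload(smooth_model_cls, quantized_model, save_dir, model_basename):
    quantized_model.save_quantized(save_dir, model_basename, use_safetensors=True)
    reloaded = smooth_model_cls.from_quantized(save_dir, model_basename, use_safetensors=True)
    # Size the int8 KV cache before calling generate(); generate() builds a BatchInferState
    # on top of this cache manager.
    reloaded.init_cache_manager(max_total_token_num=4096)
    return reloaded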
@ -0,0 +1,179 @@
|
||||
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
|
||||
|
||||
import torch
|
||||
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
|
||||
from torch_int.functional.quantization import quantize_per_tensor_absmax
|
||||
|
||||
try:
|
||||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
|
||||
|
||||
smoothquant_cuda = SmoothquantBuilder().load()
|
||||
HAS_SMOOTHQUANT_CUDA = True
|
||||
except ImportError:
|
||||
HAS_SMOOTHQUANT_CUDA = False
|
||||
raise ImportError("CUDA smoothquant linear is not installed")
|
||||
|
||||
|
||||
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(self.out_features, self.in_features),
|
||||
dtype=torch.int8,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
|
||||
)
|
||||
self.register_buffer("a", torch.tensor(alpha))
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.weight = self.weight.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x_shape[-1])
|
||||
y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
|
||||
y = y.view(*x_shape[:-1], -1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.Linear, input_scale):
|
||||
int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
|
||||
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
|
||||
alpha = input_scale * weight_scale
|
||||
int8_module.weight = int8_weight
|
||||
if module.bias is not None:
|
||||
int8_module.bias.data.copy_(module.bias.to(torch.float))
|
||||
int8_module.a = alpha
|
||||
return int8_module
|
||||
|
||||
|
||||
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
|
||||
class W8A8B8O8Linear(torch.nn.Module):
|
||||
# For qkv_proj
|
||||
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(self.out_features, self.in_features),
|
||||
dtype=torch.int8,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
|
||||
)
|
||||
self.register_buffer("a", torch.tensor(alpha))
|
||||
self.register_buffer("b", torch.tensor(beta))
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.weight = self.weight.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x_shape[-1])
|
||||
y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
|
||||
y = y.view(*x_shape[:-1], -1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.Linear, input_scale, output_scale):
|
||||
int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
|
||||
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
|
||||
alpha = input_scale * weight_scale / output_scale
|
||||
int8_module.weight = int8_weight
|
||||
int8_module.a = alpha
|
||||
|
||||
if module.bias is not None:
|
||||
int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
|
||||
int8_module.bias = int8_bias
|
||||
beta = bias_scale / output_scale
|
||||
int8_module.b = beta
|
||||
|
||||
return int8_module
|
||||
|
||||
|
||||
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
|
||||
class W8A8BFP32OFP32Linear(torch.nn.Module):
|
||||
# For fc2 and out_proj
|
||||
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(self.out_features, self.in_features),
|
||||
dtype=torch.int8,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
|
||||
)
|
||||
self.register_buffer("a", torch.tensor(alpha))
|
||||
|
||||
def _apply(self, fn):
|
||||
# prevent the bias from being converted to half
|
||||
super()._apply(fn)
|
||||
self.bias = self.bias.to(torch.float32)
|
||||
return self
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.weight = self.weight.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(torch.float32)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x_shape[-1])
|
||||
y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
|
||||
y = y.view(*x_shape[:-1], -1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.Linear, input_scale):
|
||||
int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
|
||||
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
|
||||
alpha = input_scale * weight_scale
|
||||
int8_module.weight = int8_weight
|
||||
int8_module.a = alpha
|
||||
int8_module.input_scale = input_scale
|
||||
int8_module.weight_scale = weight_scale
|
||||
|
||||
if module.bias is not None:
|
||||
int8_module.bias = module.bias.to(torch.float32)
|
||||
|
||||
return int8_module
|
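
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# How the three SmoothQuant linear flavours are built from one calibrated fp16 nn.Linear.
# `input_scale`/`output_scale` are assumed to come from the activation calibration pass
# (typically max|activation| / 127); the helper name is hypothetical.
def _example_quantize_linear(fp16_linear, input_scale, output_scale):
    qkv_style = W8A8B8O8Linear.from_float(fp16_linear, input_scale, output_scale)  # int8 in, int8 out
    proj_style = W8A8BFP32OFP32Linear.from_float(fp16_linear, input_scale)  # int8 in, fp32 out
    gate_style = W8A8BFP32O32LinearSiLU.from_float(fp16_linear, input_scale)  # fused SiLU epilogue
    return qkv_style, proj_style, gate_style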
@ -0,0 +1,838 @@
|
||||
import math
|
||||
import os
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LLAMA_INPUTS_DOCSTRING,
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaMLP,
|
||||
LlamaRotaryEmbedding,
|
||||
repeat_kv,
|
||||
rotate_half,
|
||||
)
|
||||
from transformers.utils import add_start_docstrings_to_model_forward
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import (
|
||||
copy_kv_cache_to_dest,
|
||||
int8_rotary_embedding_fwd,
|
||||
smooth_llama_context_attn_fwd,
|
||||
smooth_token_attention_fwd,
|
||||
)
|
||||
|
||||
from .base_model import BaseSmoothForCausalLM
|
||||
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
|
||||
|
||||
|
||||
class LLamaSmoothquantAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
|
||||
if (self.head_dim * num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
|
||||
self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
|
||||
self.pv_bmm = BMM_S8T_S8N_S8T(1.0)
|
||||
|
||||
self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
|
||||
self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
|
||||
self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
|
||||
self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)
|
||||
|
||||
self.register_buffer("q_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("k_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("v_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("q_rotary_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("k_rotary_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("out_input_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("attn_input_scale", torch.tensor([1.0]))
|
||||
|
||||
self._init_rope()
|
||||
self.num_key_value_heads = num_heads
|
||||
|
||||
def _init_rope(self):
|
||||
self.rotary_emb = LlamaRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000.0,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def pack(
|
||||
module: LlamaAttention,
|
||||
attn_input_scale: float,
|
||||
q_output_scale: float,
|
||||
k_output_scale: float,
|
||||
v_output_scale: float,
|
||||
q_rotary_output_scale: float,
|
||||
k_rotary_output_scale: float,
|
||||
out_input_scale: float,
|
||||
):
|
||||
int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
|
||||
|
||||
int8_module.attn_input_scale = torch.tensor([attn_input_scale])
|
||||
|
||||
int8_module.q_output_scale = torch.tensor([q_output_scale])
|
||||
int8_module.k_output_scale = torch.tensor([k_output_scale])
|
||||
int8_module.v_output_scale = torch.tensor([v_output_scale])
|
||||
|
||||
int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale])
|
||||
int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale])
|
||||
|
||||
int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale)
|
||||
int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale)
|
||||
int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
|
||||
int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)
|
||||
|
||||
int8_module.out_input_scale = torch.tensor([out_input_scale])
|
||||
|
||||
return int8_module
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
rotary_emb: Tuple[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
cos = rotary_emb[0]
|
||||
sin = rotary_emb[1]
|
||||
|
||||
int8_rotary_embedding_fwd(
|
||||
query_states.view(-1, self.num_heads, self.head_dim),
|
||||
cos,
|
||||
sin,
|
||||
self.q_output_scale.item(),
|
||||
self.q_rotary_output_scale.item(),
|
||||
)
|
||||
int8_rotary_embedding_fwd(
|
||||
key_states.view(-1, self.num_heads, self.head_dim),
|
||||
cos,
|
||||
sin,
|
||||
self.k_output_scale.item(),
|
||||
self.k_rotary_output_scale.item(),
|
||||
)
|
||||
|
||||
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
||||
return
|
||||
|
||||
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
# first token generation
|
||||
|
||||
# copy key and value calculated in current step to memory manager
|
||||
_copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_states,
|
||||
value_states,
|
||||
infer_state.context_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
smooth_llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
self.q_rotary_output_scale.item(),
|
||||
self.k_rotary_output_scale.item(),
|
||||
self.v_output_scale.item(),
|
||||
self.out_input_scale.item(),
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
q_len,
|
||||
)
|
||||
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_k.copy_(key_states)
|
||||
cache_v.copy_(value_states)
|
||||
else:
|
||||
# if decode is not contiguous, use triton kernel to copy key and value cache
|
||||
# k, v shape: [batch_size, num_heads, head_dim / embed_size_per_head]
|
||||
_copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_states,
|
||||
value_states,
|
||||
infer_state.decode_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
# (batch_size, seqlen, nheads, headdim)
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
smooth_token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
self.q_rotary_output_scale.item(),
|
||||
self.k_rotary_output_scale.item(),
|
||||
self.v_output_scale.item(),
|
||||
self.out_input_scale.item(),
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, None
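
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# Builds the int8 attention from a calibrated fp16 LlamaAttention. `scales` is an assumed
# dict holding the seven per-tensor scales that the calibration pass produces.
def _example_pack_attention(fp16_attn, scales):
    return LLamaSmoothquantAttention.pack(
        fp16_attn,
        scales["attn_input"],
        scales["q_output"],
        scales["k_output"],
        scales["v_output"],
        scales["q_rotary_output"],
        scales["k_rotary_output"],
        scales["out_input"],
    )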
|
||||
|
||||
|
||||
class LlamaLayerNormQ(torch.nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.input_scale = 1.0
|
||||
self.variance_epsilon = eps
|
||||
self.register_buffer("weight", torch.ones(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon)
|
||||
ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8)
|
||||
return ln_output_int8
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.LayerNorm, output_scale: float):
|
||||
assert module.weight.shape[0] == module.weight.numel()
|
||||
q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
|
||||
q_module.weight = module.weight / output_scale
|
||||
return q_module
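
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# Despite the nn.LayerNorm annotation, from_float expects an RMSNorm-style module exposing
# `weight` and `variance_epsilon` (e.g. LlamaRMSNorm); folding 1/output_scale into the weight
# lets the layer emit activations already in the int8 domain of the following linear.
def _example_quantize_layernorm(rms_norm, attn_input_scale):
    return LlamaLayerNormQ.from_float(rms_norm, attn_input_scale)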
|
||||
|
||||
|
||||
class LlamaSmoothquantMLP(nn.Module):
|
||||
def __init__(self, intermediate_size, hidden_size):
|
||||
super().__init__()
|
||||
self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
|
||||
self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
|
||||
self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
|
||||
self.register_buffer("down_proj_input_scale", torch.tensor([1.0]))
|
||||
|
||||
@staticmethod
|
||||
def pack(
|
||||
mlp_module: LlamaMLP,
|
||||
gate_proj_input_scale: float,
|
||||
up_proj_input_scale: float,
|
||||
down_proj_input_scale: float,
|
||||
):
|
||||
int8_module = LlamaSmoothquantMLP(
|
||||
mlp_module.intermediate_size,
|
||||
mlp_module.hidden_size,
|
||||
)
|
||||
|
||||
int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
|
||||
int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
|
||||
int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
|
||||
int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale])
|
||||
return int8_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
x_shape = hidden_states.shape
|
||||
gate_out = self.gate_proj(hidden_states)
|
||||
up_out = self.up_proj(hidden_states)
|
||||
inter_out = gate_out * up_out
|
||||
inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8)
|
||||
down_out = self.down_proj(inter_out)
|
||||
down_out = down_out.view(*x_shape[:-1], -1)
|
||||
return down_out
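
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# Builds the int8 MLP from a calibrated fp16 LlamaMLP. The `act_dict` layout and the
# division by 127 are assumptions about how the per-projection input scales are derived
# from the collected activation maxima.
def _example_pack_mlp(fp16_mlp, act_dict):
    return LlamaSmoothquantMLP.pack(
        fp16_mlp,
        gate_proj_input_scale=act_dict["gate_proj"]["input"] / 127,
        up_proj_input_scale=act_dict["up_proj"]["input"] / 127,
        down_proj_input_scale=act_dict["down_proj"]["input"] / 127,
    )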
|
||||
|
||||
|
||||
class LlamaSmoothquantDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads)
|
||||
|
||||
self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size)
|
||||
self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@staticmethod
|
||||
def pack(
|
||||
module: LlamaDecoderLayer,
|
||||
attn_input_scale: float,
|
||||
q_output_scale: float,
|
||||
k_output_scale: float,
|
||||
v_output_scale: float,
|
||||
q_rotary_output_scale: float,
|
||||
k_rotary_output_scale: float,
|
||||
out_input_scale: float,
|
||||
gate_input_scale: float,
|
||||
up_input_scale: float,
|
||||
down_input_scale: float,
|
||||
):
|
||||
config = module.self_attn.config
|
||||
int8_decoder_layer = LlamaSmoothquantDecoderLayer(config)
|
||||
|
||||
int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale)
|
||||
int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack(
|
||||
module.self_attn,
|
||||
attn_input_scale,
|
||||
q_output_scale,
|
||||
k_output_scale,
|
||||
v_output_scale,
|
||||
q_rotary_output_scale,
|
||||
k_rotary_output_scale,
|
||||
out_input_scale,
|
||||
)
|
||||
|
||||
int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
|
||||
module.post_attention_layernorm, gate_input_scale
|
||||
)
|
||||
|
||||
int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack(
|
||||
module.mlp,
|
||||
gate_input_scale,
|
||||
up_input_scale,
|
||||
down_input_scale,
|
||||
)
|
||||
|
||||
return int8_decoder_layer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
rotary_emb: Tuple[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
rotary_emb=rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, None, None
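
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# Assembles one int8 decoder layer from a calibrated fp16 LlamaDecoderLayer. `s` is an
# assumed mapping of the ten calibration scales listed in pack()'s signature.
def _example_pack_decoder_layer(fp16_layer, s):
    return LlamaSmoothquantDecoderLayer.pack(
        fp16_layer,
        s["attn_input"], s["q_output"], s["k_output"], s["v_output"],
        s["q_rotary_output"], s["k_rotary_output"], s["out_input"],
        s["gate_input"], s["up_input"], s["down_input"],
    )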
|
||||
|
||||
|
||||
class LlamaApplyRotary(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
return x_embed
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states = self.q_apply_rotary(query_states, cos, sin, position_ids)
|
||||
key_states = self.k_apply_rotary(key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def init_to_get_rotary(config, base=10000, use_elem=False):
|
||||
"""
|
||||
This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer
|
||||
Args:
|
||||
base : calculation arg
|
||||
use_elem : activated when using chatglm-based models
|
||||
"""
|
||||
config.head_dim_ = config.hidden_size // config.num_attention_heads
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
rope_scaling_factor = 1.0
|
||||
else:
|
||||
rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0
|
||||
|
||||
if hasattr(config, "max_sequence_length"):
|
||||
max_seq_len = config.max_sequence_length
|
||||
elif hasattr(config, "max_position_embeddings"):
|
||||
max_seq_len = config.max_position_embeddings * rope_scaling_factor
|
||||
else:
|
||||
max_seq_len = 2048 * rope_scaling_factor
|
||||
base = float(base)
|
||||
|
||||
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
try:
|
||||
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
|
||||
assert ntk_alpha >= 1
|
||||
if ntk_alpha > 1:
|
||||
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
|
||||
max_seq_len *= ntk_alpha
|
||||
base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula
|
||||
except:
|
||||
pass
|
||||
|
||||
n_elem = config.head_dim_
|
||||
if use_elem:
|
||||
n_elem //= 2
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
|
||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
_cos_cached = torch.cos(freqs).to(torch.float)
|
||||
_sin_cached = torch.sin(freqs).to(torch.float)
|
||||
return _cos_cached, _sin_cached
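
# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# The cos/sin tables are computed once per model and attached as `_cos_cached`/`_sin_cached`,
# which is where llama_model_forward below indexes them by position id. Attaching them to
# `llama_model` (the inner LlamaModel) is an assumption for the example.
def _example_build_rotary_cache(llama_model, config):
    cos_cached, sin_cached = init_to_get_rotary(config, base=10000)
    llama_model._cos_cached = cos_cached.cuda()
    llama_model._sin_cached = sin_cached.cuda()
    return llama_model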
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
infer_state = self.infer_state
|
||||
if infer_state.is_context_stage:
|
||||
past_key_values_length = 0
|
||||
else:
|
||||
past_key_values_length = infer_state.max_len_in_batch - 1
|
||||
|
||||
seq_length_with_past = seq_length + past_key_values_length
|
||||
|
||||
# NOTE: differentiate with prefill stage
# block_loc requires a different value-assigning method for the two stages
|
||||
if infer_state.is_context_stage:
|
||||
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
|
||||
infer_state.init_block_loc(
|
||||
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
|
||||
)
|
||||
else:
|
||||
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
|
||||
if alloc_mem is not None:
|
||||
infer_state.decode_is_contiguous = True
|
||||
infer_state.decode_mem_index = alloc_mem[0]
|
||||
infer_state.decode_mem_start = alloc_mem[1]
|
||||
infer_state.decode_mem_end = alloc_mem[2]
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
else:
|
||||
print(f" *** Encountered allocation non-contiguous")
|
||||
print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}")
|
||||
infer_state.decode_is_contiguous = False
|
||||
alloc_mem = infer_state.cache_manager.alloc(batch_size)
|
||||
infer_state.decode_mem_index = alloc_mem
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
raise NotImplementedError("not implement gradient_checkpointing and training options ")
|
||||
|
||||
if past_key_values_length == 0:
|
||||
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
else:
|
||||
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
|
||||
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
infer_state.decode_layer_id = 0
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
rotary_emb=(position_cos, position_sin),
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
infer_state.decode_layer_id += 1
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
infer_state.is_context_stage = False
|
||||
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.seq_len += 1
|
||||
infer_state.max_len_in_batch += 1
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
    layer_type = "LlamaDecoderLayer"

    def __init__(self, model: PreTrainedModel, quantized: bool = False):
        super().__init__(model, quantized)

    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
    def get_act_dict(
        self,
        tokenizer,
        dataset,
        num_samples=512,
        seq_len=512,
    ):
        llama_model = self.model

        llama_model.eval()
        device = next(llama_model.parameters()).device
        # print("model:", llama_model)
        act_dict = defaultdict(dict)

        def stat_io_hook(m, x, y, name):
            if isinstance(x, tuple):
                x = x[0]
            if name not in act_dict or "input" not in act_dict[name]:
                act_dict[name]["input"] = x.detach().abs().max().item()
            else:
                act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item())
            if isinstance(y, tuple):
                y = y[0]
            if name not in act_dict or "output" not in act_dict[name]:
                act_dict[name]["output"] = y.detach().abs().max().item()
            else:
                act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item())

        for name, m in llama_model.named_modules():
            if isinstance(m, LlamaAttention):
                setattr(m, "q_apply_rotary", LlamaApplyRotary())
                setattr(m, "k_apply_rotary", LlamaApplyRotary())
                m.forward = types.MethodType(llama_decoder_layer_forward, m)

        hooks = []
        for name, m in llama_model.named_modules():
            if isinstance(m, LlamaApplyRotary):
                hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
            if isinstance(m, torch.nn.Linear):
                hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))

        self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)

        for hook in hooks:
            hook.remove()
        return act_dict

    def smooth_fn(self, scales, alpha=0.5):
        model = self.model
        for name, module in model.named_modules():
            if isinstance(module, LlamaDecoderLayer):
                attn_ln = module.input_layernorm
                qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
                qkv_input_scales = scales[name + ".self_attn.q_proj"]
                self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

    def create_quantized_model(model):
        llama_config = model.config
        for i, layer in enumerate(model.model.layers):
            model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config)

        model.model.forward = types.MethodType(llama_model_forward, model.model)
        cos, sin = init_to_get_rotary(llama_config)
        model.model.register_buffer("_cos_cached", cos)
        model.model.register_buffer("_sin_cached", sin)

    def quantized(
        self,
        tokenizer,
        dataset,
        num_samples=512,
        seq_len=512,
        alpha=0.5,
    ):
        llama_model = self.model
        llama_config = llama_model.config

        act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)

        self.smooth_fn(act_scales, alpha)

        act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
        decoder_layer_scales = []

        for idx in range(llama_config.num_hidden_layers):
            scale_dict = {}
            scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127
            scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127
            scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127
            scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127

            scale_dict["q_rotary_output_scale"] = (
                act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
            )
            scale_dict["k_rotary_output_scale"] = (
                act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
            )

            scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127

            scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
            scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
            scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127

            decoder_layer_scales.append(scale_dict)

        for i, layer in enumerate(llama_model.model.layers):
            orig_layer = layer
            llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i])

        llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model)

        cos, sin = init_to_get_rotary(llama_config)
        llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device))
        llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device))
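
The scale dictionary built in quantized() above follows the usual symmetric int8 convention: each entry is the maximum absolute activation recorded during calibration divided by 127. A minimal, self-contained sketch of that conversion; the act_dict entry below is a made-up toy value, not a real calibration result:

# Minimal sketch: turning calibration maxima into int8 quantization scales.
# `toy_act_dict` is a hypothetical stand-in for the dict produced by get_act_dict.
toy_act_dict = {"model.layers.0.self_attn.q_proj": {"input": 6.35, "output": 12.7}}

def to_scale(abs_max: float) -> float:
    # symmetric int8: the representable integer range is [-127, 127]
    return abs_max / 127

attn_input_scale = to_scale(toy_act_dict["model.layers.0.self_attn.q_proj"]["input"])
q_output_scale = to_scale(toy_act_dict["model.layers.0.self_attn.q_proj"]["output"])
assert round(attn_input_scale, 4) == 0.05
assert round(q_output_scale, 1) == 0.1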
@ -0,0 +1,118 @@
# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass

import torch
from transformers.tokenization_utils_base import BatchEncoding

from .kvcache_manager import MemoryManager


# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
    r"""
    Information to be passed and used for a batch of inputs during
    a single model forward
    """
    batch_size: int
    max_len_in_batch: int

    cache_manager: MemoryManager = None

    block_loc: torch.Tensor = None
    start_loc: torch.Tensor = None
    seq_len: torch.Tensor = None
    past_key_values_len: int = None

    is_context_stage: bool = False
    context_mem_index: torch.Tensor = None
    decode_is_contiguous: bool = None
    decode_mem_start: int = None
    decode_mem_end: int = None
    decode_mem_index: torch.Tensor = None
    decode_layer_id: int = None

    device: torch.device = torch.device("cuda")

    @property
    def total_token_num(self):
        # return self.batch_size * self.max_len_in_batch
        assert self.seq_len is not None and self.seq_len.size(0) > 0
        return int(torch.sum(self.seq_len))

    def set_cache_manager(self, manager: MemoryManager):
        self.cache_manager = manager

    # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
    @staticmethod
    def init_block_loc(
        b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
    ):
        """in-place update of the block-loc mapping based on the sequence lengths of the inputs in the current batch"""
        start_index = 0
        seq_len_numpy = seq_len.cpu().numpy()
        for i, cur_seq_len in enumerate(seq_len_numpy):
            b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
                start_index : start_index + cur_seq_len
            ]
            start_index += cur_seq_len
        return

    @classmethod
    def init_from_batch(
        cls,
        batch: torch.Tensor,
        max_input_len: int,
        max_output_len: int,
        cache_manager: MemoryManager,
    ):
        if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
            raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")

        input_ids_list = None
        attention_mask = None

        if isinstance(batch, (BatchEncoding, dict)):
            input_ids_list = batch["input_ids"]
            attention_mask = batch["attention_mask"]
        else:
            input_ids_list = batch
        if isinstance(input_ids_list[0], int):  # for a single input
            input_ids_list = [input_ids_list]
            attention_mask = [attention_mask] if attention_mask is not None else attention_mask

        batch_size = len(input_ids_list)

        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        start_index = 0

        max_len_in_batch = -1
        if isinstance(batch, (BatchEncoding, dict)):
            for i, attn_mask in enumerate(attention_mask):
                curr_seq_len = len(attn_mask)
                seq_lengths[i] = curr_seq_len
                seq_start_indexes[i] = start_index
                start_index += curr_seq_len
                max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
        else:
            length = max(len(input_id) for input_id in input_ids_list)
            for i, input_ids in enumerate(input_ids_list):
                curr_seq_len = length
                seq_lengths[i] = curr_seq_len
                seq_start_indexes[i] = start_index
                start_index += curr_seq_len
                max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
        block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda")

        return cls(
            batch_size=batch_size,
            max_len_in_batch=max_len_in_batch,
            seq_len=seq_lengths.to("cuda"),
            start_loc=seq_start_indexes.to("cuda"),
            block_loc=block_loc,
            decode_layer_id=0,
            past_key_values_len=0,
            is_context_stage=True,
            cache_manager=cache_manager,
        )
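
A minimal sketch of what init_block_loc does to the block table: the allocated cache-slot indexes are written right-aligned into each row, so shorter sequences keep untouched padding positions at the front. The tensors below are toy values, not taken from a real run, and the snippet assumes BatchInferState from the file above is importable in scope.

import torch

# Toy example: batch of 2 sequences with lengths 2 and 3, max_len_in_batch = 3.
b_loc = torch.zeros((2, 3), dtype=torch.long)
seq_len = torch.tensor([2, 3], dtype=torch.int32)
alloc_mem_index = torch.arange(100, 105)  # 5 allocated cache slots

BatchInferState.init_block_loc(b_loc, seq_len, max_len_in_batch=3, alloc_mem_index=alloc_mem_index)

# Row 0 keeps its first column untouched (padding); rows are filled right-aligned:
# tensor([[  0, 100, 101],
#         [102, 103, 104]])
print(b_loc)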
@ -0,0 +1,106 @@
"""
Referred/Modified from lightllm/common/mem_manager.py
of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
We slightly changed it to make it suitable for our Colossal-AI ShardFormer TP-engine design.
"""
import torch
from transformers.utils import logging


class MemoryManager:
    r"""
    Manage token block indexes and allocate physical memory for key and value cache

    Args:
        size: maximum token number used as the size of key and value buffer
        dtype: data type of cached key and value
        head_num: number of heads the memory manager is responsible for
        head_dim: embedded size per head
        layer_num: the number of layers in the model
        device: device used to store the key and value cache
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: torch.device = torch.device("cuda"),
    ):
        self.logger = logging.get_logger(__name__)
        self.available_size = size
        self.max_len_in_batch = 0
        self._init_mem_states(size, device)
        self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)

    def _init_mem_states(self, size, device):
        """Initialize tensors used to manage memory states"""
        self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
        self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
        self.indexes = torch.arange(0, size, dtype=torch.long, device=device)

    def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
        """Initialize key buffer and value buffer on the specified device"""
        self.key_buffer = [
            torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
        ]
        self.value_buffer = [
            torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
        ]

    @torch.no_grad()
    def alloc(self, required_size):
        """allocate space of required_size by providing indexes of available physical slots"""
        if required_size > self.available_size:
            self.logger.warning(f"Not enough cache: required_size {required_size} left_size {self.available_size}")
            return None
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
        select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
        select_index = self.indexes[select_index]
        self.mem_state[select_index] = 0
        self.available_size -= len(select_index)
        return select_index

    @torch.no_grad()
    def alloc_contiguous(self, required_size):
        """allocate a contiguous space of required_size"""
        if required_size > self.available_size:
            self.logger.warning(f"Not enough cache: required_size {required_size} left_size {self.available_size}")
            return None
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
        sum_size = len(self.mem_cum_sum)
        loc_sums = (
            self.mem_cum_sum[required_size - 1 :]
            - self.mem_cum_sum[0 : sum_size - required_size + 1]
            + self.mem_state[0 : sum_size - required_size + 1]
        )
        can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
        if can_used_loc.shape[0] == 0:
            self.logger.info(
                f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}"
            )
            return None
        start_loc = can_used_loc[0]
        select_index = self.indexes[start_loc : start_loc + required_size]
        self.mem_state[select_index] = 0
        self.available_size -= len(select_index)
        start = start_loc.item()
        end = start + required_size
        return select_index, start, end

    @torch.no_grad()
    def free(self, free_index):
        """free memory by updating memory states based on the given indexes"""
        self.available_size += free_index.shape[0]
        self.mem_state[free_index] = 1

    @torch.no_grad()
    def free_all(self):
        """free all memory by resetting memory states"""
        self.available_size = len(self.mem_state)
        self.mem_state[:] = 1
        self.max_len_in_batch = 0
        self.logger.info("freed all space of memory manager")
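
A minimal usage sketch of the manager above on a deliberately tiny cache, so the bookkeeping is easy to follow; the sizes and dtype are illustrative only, and the default device assumes a CUDA-capable GPU is available.

import torch

# Tiny cache: 8 token slots, 2 heads of dim 4, 1 layer (assumes MemoryManager from the file above).
manager = MemoryManager(size=8, dtype=torch.float16, head_num=2, head_dim=4, layer_num=1)

prefill_idx = manager.alloc(5)        # any 5 free slots for the prompt tokens
step = manager.alloc_contiguous(2)    # (indexes, start, end), or None if the cache is fragmented
assert manager.available_size == 1

manager.free(prefill_idx)             # return the prompt slots
manager.free_all()                    # reset everything
assert manager.available_size == 8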
@ -0,0 +1,67 @@
"""
Utils for model inference
"""
import os

import torch

from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest


def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
    """
    Copy the key and value cache of one layer into the memory cache held by the kv cache manager.
    Args:
        layer_id : id of current layer
        key_buffer : key cache
        value_buffer : value cache
        context_mem_index : index of memory cache in kv cache manager
        mem_manager : cache manager
    """
    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])


def init_to_get_rotary(self, base=10000, use_elem=False):
    """
    Initialize the rotary positional embedding cache. It is compatible with all models and is called in ShardFormer.
    Args:
        self : model that holds the rotary positional embedding
        base : base used to compute the rotary frequencies
        use_elem : activated when using chatglm-based models
    """
    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
    if not hasattr(self.config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0

    if hasattr(self.config, "max_sequence_length"):
        max_seq_len = self.config.max_sequence_length
    elif hasattr(self.config, "max_position_embeddings"):
        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)

    if ntk_alpha is not None:
        ntk_alpha = float(ntk_alpha)
        assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
        if ntk_alpha > 1:
            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
        max_seq_len *= ntk_alpha
        base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula

    n_elem = self.config.head_dim_
    if use_elem:
        n_elem //= 2

    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
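
The cached tables built by init_to_get_rotary are just the cosine and sine of the standard RoPE frequencies. A small self-contained sketch of the same computation on toy sizes (head_dim 8, 16 positions), independent of any model config:

import torch

head_dim, max_positions, base = 8, 16, 10000.0

# One inverse frequency per pair of channels, as in init_to_get_rotary.
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_positions, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)          # (positions, head_dim // 2)

cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)
assert cos_cached.shape == (max_positions, head_dim // 2)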
@ -0,0 +1,134 @@
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

colossalai.launch_from_torch(config={})


def data_gen(batch_size: int = 4, seq_len: int = 512):
    input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
    attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
    data = dict(input_ids=input_ids, attention_mask=attention_mask)
    for k, v in data.items():
        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
            new_shape = [1] * v.dim()
            new_shape[0] = batch_size
            data[k] = v.to("cuda").repeat(*new_shape)
    return data


def print_details_info(timestamps, model_config, args, whole_end2end):
    if dist.get_rank() == 0:
        prefill = []
        encoder = []
        end2end = []
        for timestamp in timestamps:
            prefill.append(timestamp[1] - timestamp[0])
            encoder.append(
                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
            )
            end2end.append(timestamp[-1] - timestamp[0])
        print(whole_end2end)
        with open(
            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
            "w+",
        ) as f:
            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
            num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
            num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
            if args.dtype in ["fp16", "bf16"]:
                num_bytes = 2
            else:
                num_bytes = 4

            f.write(
                f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
            )
            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
            f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
            f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
            f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
            f.write("----------------------------------------------------------\n")

    if torch.cuda.is_available():
        current_device = torch.cuda.current_device()

        # free memory and the total available memory in bytes
        global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
        memory_allocated = torch.cuda.memory_allocated()
        max_memory_allocated = torch.cuda.max_memory_allocated()
        memory_reserved = torch.cuda.memory_reserved()
        max_memory_reserved = torch.cuda.max_memory_reserved()
        with open(
            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
            "a",
        ) as f:
            f.write(
                f"\nCurrently using GPU: {current_device}\n"
                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
                f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
                f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
                f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="toy", help="the size of model")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
    parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
    parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
    parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
    args = parser.parse_args()

    if args.model == "toy":
        model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
    elif args.model == "7b":
        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
    elif args.model == "13b":
        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
    else:
        raise NotImplementedError

    engine = PPInferEngine(
        pp_size=args.pp_size,
        dtype=args.dtype,
        micro_batch_size=args.mb_size,
        new_length=args.new_length,
        model=model,
        model_policy=LlamaModelInferPolicy(),
        verbose=True,
        max_batch_size=args.mb_size,
        max_input_len=args.seq_len,
        max_output_len=args.seq_len + args.new_length + 256,
    )
    data = data_gen(args.batch_size, args.seq_len)

    torch.cuda.synchronize()
    whole_end2end = time.time()
    output, timestamps = engine.inference([data])
    torch.cuda.synchronize()
    whole_end2end = time.time() - whole_end2end

    print_details_info(timestamps, model.config, args, whole_end2end)
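
For reference, the reported latency and throughput reduce to simple arithmetic over the measured wall-clock time; a toy example with made-up numbers (0.8 s end to end, 128 new tokens, batch size 8):

# Hypothetical numbers, only to show how the benchmark derives its metrics.
whole_end2end = 0.8          # seconds for the whole batch
new_length, batch_size = 128, 8

whole_avg_latency = whole_end2end / (new_length * batch_size)   # seconds per generated token
throughput = 1.0 / whole_avg_latency                            # tokens per second

print(f"Per Token Latency: {whole_avg_latency * 1000:.2f} ms")  # 0.78 ms
print(f"Throughput: {throughput:.0f} tokens/s")                 # 1280 tokens/s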
@ -0,0 +1,50 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"

# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="7b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=1024 \
        --new_length=128 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done

# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="7b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=512 \
        --new_length=512 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done

# 13b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="13b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=1024 \
        --new_length=128 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done

# 13b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="13b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=512 \
        --new_length=512 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done
@ -1,70 +0,0 @@
import pytest
import torch
from packaging import version
from transformers import BloomForCausalLM
from transformers.models.bloom.configuration_bloom import BloomConfig

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

try:
    import lightllm
    HAS_LIGHTLLM_KERNEL = True
except:
    HAS_LIGHTLLM_KERNEL = False

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
    "test_config",
    [
        {
            "tp_size": TP_SIZE,
        }
    ],
)
def run(test_config):
    bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = BloomForCausalLM(bloom_config)
    model = model.half()

    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
        "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

    assert outputs is not None


def check_bloom(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom_infer():
    spawn(check_bloom, TP_SIZE)


if __name__ == "__main__":
    test_bloom_infer()
@ -1,83 +0,0 @@
import os

import pytest
import torch
from packaging import version

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

try:
    import lightllm  # noqa

    HAS_LIGHTLLM_KERNEL = True
except:
    HAS_LIGHTLLM_KERNEL = False

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
    "test_config",
    [
        {
            "tp_size": TPSIZE,
        }
    ],
)
def run_chatglm2_test(test_config):
    chatglm_config = ChatGLMConfig(
        num_layers=2,
        vocab_size=1200,
        use_cache=True,
        multi_query_attention=True,
        multi_query_group_num=2,
        num_attention_heads=8,
        hidden_size=1024,
    )
    model = ChatGLMForConditionalGeneration(chatglm_config)
    model = model.half()

    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
    assert outputs is not None


def check_chatglm2(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_chatglm2_test()


@pytest.mark.skipif(
    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm2():
    spawn(check_chatglm2, TPSIZE)


if __name__ == "__main__":
    test_chatglm2()
@ -1,14 +0,0 @@
engine_config:
  model: MODEL_PATH
  tensor_parallel_size: 1
  max_batch_size: 2
  max_input_len: 1024
  max_output_len: 512
# config for app router deployment
# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
router_config:
  max_total_token_num: 4096
  batch_max_tokens: 4096
  disable_log_stats: False
  log_stats_interval: 10
  model: MODEL_PATH
@ -1,61 +0,0 @@
import asyncio
import os
import uuid

import pytest

import colossalai
from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

PATH = "config.yaml"


def run_async_engine(path: str):
    if not os.path.exists(path):
        return

    config = RayInitConfig.from_yaml_path(path)
    engine_config = config.engine_config_data
    model = engine_config.model
    if model is None or not os.path.exists(model):
        return

    prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
    sampling_params = SamplingParams()
    asyncio.run(asy_for_loop_test(config, prompt, sampling_params))


async def get_result(engine, prompt, sampling_params):
    request_id = str(uuid.uuid4().hex)
    results = engine.generate(request_id, prompt, sampling_params)
    async for result in results:
        # print(result)
        assert result is not None


async def asy_for_loop_test(config, prompt, sampling_params):
    router_config = config.router_config_data
    engine_config = config.engine_config_data
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    for i in range(10):
        print("in for loop", i)
        await get_result(engine, prompt, sampling_params)


def check_async_engine(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_async_engine(PATH)


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_async_engine():
    spawn(check_async_engine, 1)


if __name__ == "__main__":
    test_async_engine()
@ -1,95 +0,0 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 48
MAX_OUTPUT_LEN = 256


def run():
    sampling_params = SamplingParams()

    req1 = Req(0, [1], sampling_params)
    req2 = Req(1, [2], sampling_params)
    req3 = Req(2, [3], sampling_params)
    # req 1-3 are initialized as token forward requests
    req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)

    # init model and tp engine
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    dynamic_batch_manager = DynamicBatchManager(
        tp_engine=infer_engine,
        max_total_token_num=640,
        batch_max_tokens=608,
        eos_id=0,
        log_stats=False,
        log_stats_interval=10,
        waiting_req_list=waiting_list,
        model="llama",
    )
    before_add = len(dynamic_batch_manager.req_queue)

    # test add req function
    dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params)
    assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1

    # test abort function
    dynamic_batch_manager.abort(req4.request_id)
    assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True

    # test filter batch function; loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested
    batch = dynamic_batch_manager.req_queue.generate_new_batch()
    assert len(batch) == 2

    dynamic_batch_manager._init_batch(batch)
    assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None

    batch.reqs[0].has_generate_finished = True
    # filter one finished
    batch.filter_finished()
    dynamic_batch_manager._filter_batch(batch)
    assert len(dynamic_batch_manager.engine.cache) == 1

    # test merge batch
    new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
    assert len(new_batch) == 1
    dynamic_batch_manager._init_batch(new_batch)
    dynamic_batch_manager._merge_batch(batch, new_batch)

    assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2


def check_dynamic_batching_manager(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
    spawn(check_dynamic_batching_manager, 1)


if __name__ == "__main__":
    test_dynamic_batching_manager()
@ -1,84 +0,0 @@
from dataclasses import dataclass

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@dataclass
class args:
    max_total_token_num: int
    batch_max_tokens: int
    model: str
    eos_id: int
    disable_log_stats: bool
    log_stats_interval: int


def run():
    arg = args(
        max_total_token_num=42,
        model="llama",
        batch_max_tokens=42,
        eos_id=0,
        disable_log_stats=False,
        log_stats_interval=10,
    )
    sampling_params = SamplingParams()

    req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
    req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
    req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
    req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)

    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)
    waiting_list.append(req4)

    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)

    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

    ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
    for result in ans_gen:
        assert result is not None


def check_dynamic_forward(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
    spawn(check_dynamic_forward, TP_SIZE)


if __name__ == "__main__":
    test_dynamic_batching()
@ -1,66 +0,0 @@
import asyncio
import os
import uuid

import pytest

import colossalai
from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

PATH = "config.yaml"


def run_ray_dist(path: str):
    if not os.path.exists(path):
        return
    config = RayInitConfig.from_yaml_path(path)
    router_config = config.router_config_data
    engine_config = config.engine_config_data
    model = engine_config.model
    if model is None or not os.path.exists(model):
        return
    driver = Driver(router_config=router_config, engine_config=engine_config)
    prompt = "Introduce some landmarks in Beijing"

    request_id = str(uuid.uuid4().hex)
    sampling_params = SamplingParams()
    print("sampling_params: ", sampling_params)

    async def get_result(request_id, prompt, sampling_params):
        return await driver.async_generate(request_id, prompt, sampling_params)

    for test_async in [True, False]:
        if test_async:
            print("test_async: ", test_async)
            result = asyncio.run(get_result(request_id, prompt, sampling_params))
            assert result is not None
            print("result: ", result)
        else:
            print("test_async: ", test_async)
            result = driver.generate(request_id, prompt, sampling_params)
            assert result is not None
            print("result: ", result)

    is_running = None
    is_running = driver.is_running()
    assert is_running is not None
    print("is_running: ", is_running)


def check_ray_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_ray_dist(PATH)


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_ray_dist():
    spawn(check_ray_dist, 1)


if __name__ == "__main__":
    test_ray_dist()
@ -1,102 +0,0 @@
from itertools import accumulate

import pytest
import torch
from packaging import version
from transformers import BloomConfig, BloomForCausalLM
from transformers.tokenization_utils_base import BatchEncoding

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
    "test_config",
    [
        {
            "tp_size": TP_SIZE,
        }
    ],
)
def run(test_config):
    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config)
    model = model.half()
    model.to(torch.cuda.current_device())

    # 1. check TPInferEngine init and model optimization
    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    assert infer_engine.cache_manager is not None
    assert infer_engine.tp_size == TP_SIZE
    assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

    # 2. check data preparation
    input_ids_list = [
        [80540, 15473, 3331, 11970, 90472, 361, 61335],
        [80540, 15473, 3331, 11970],
        [80540, 15473, 3331, 11970],
        [80540, 15473],
    ]
    batch_size = len(input_ids_list)
    max_seq_len = max(len(li) for li in input_ids_list)
    attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
    for i, li in enumerate(input_ids_list):
        attention_mask[i][max_seq_len - len(li) :] = [1 for _ in range(len(li))]
    data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
    inputs_batch_encoding = BatchEncoding(data=data)
    seq_lengths = [len(li) for li in input_ids_list]
    start_loc = list(accumulate([0] + seq_lengths[:-1]))
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
    start_loc = torch.tensor(start_loc, dtype=torch.int32)
    # BatchEncoding as inputs
    batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
    # input token id list as inputs
    batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)

    assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
    assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)

    # The following tests are discarded for now, and will be reused after all features are added
    # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
    # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)

    # 3. check optimized model generate
    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
    generate_kwargs = dict(do_sample=False)
    infer_engine.generate(input_ids, **generate_kwargs)

    torch.cuda.empty_cache()


def check_engine(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine():
    spawn(check_engine, TP_SIZE)


if __name__ == "__main__":
    test_engine()
@ -1,75 +0,0 @@
import os

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

try:
    import lightllm
    HAS_LIGHTLLM_KERNEL = True
except:
    HAS_LIGHTLLM_KERNEL = False

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
    "test_config",
    [
        {
            "tp_size": TPSIZE,
        }
    ],
)
def run_llama_test(test_config):
    llama_config = LlamaConfig(
        num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
    )
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

    assert outputs is not None


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":
    test_llama()
@ -1,73 +0,0 @@
import os

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

try:
    import lightllm
    HAS_LIGHTLLM_KERNEL = True
except:
    HAS_LIGHTLLM_KERNEL = False

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
    "test_config",
    [
        {
            "tp_size": TPSIZE,
        }
    ],
)
def run_llama_test(test_config):
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

    assert outputs is not None


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":
    test_llama()