[inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxamil.com>
pull/4951/head
Xu Kai 2023-10-20 13:39:34 +08:00 committed by GitHub
parent b8e770c832
commit 785802e809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 24 additions and 10 deletions

View File

@ -132,6 +132,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
# Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
@ -163,6 +164,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
return act_scales
# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):
@ -189,6 +191,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
def create_quantized_model(model):
raise NotImplementedError("Not implement create_quantized_model method")
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_quantized(
self,
save_dir: str,
@ -249,6 +252,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
self.model.config.save_pretrained(save_dir)
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_pretrained(
self,
save_dir: str,
@ -260,6 +264,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_pretrained(
cls,
@ -354,6 +359,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
return cls(model, False)
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_quantized(
cls,

View File

@ -62,6 +62,7 @@ class W8A8BFP32O32LinearSiLU(torch.nn.Module):
return int8_module
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
@ -117,6 +118,7 @@ class W8A8B8O8Linear(torch.nn.Module):
return int8_module
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8BFP32OFP32Linear(torch.nn.Module):
# For fc2 and out_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):

View File

@ -419,6 +419,7 @@ class LlamaApplyRotary(nn.Module):
return x_embed
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False):
return _cos_cached, _sin_cached
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
self,
@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__(model, quantized)
# Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_dict(
self,
tokenizer,

View File

@ -21,6 +21,8 @@ _supported_models = [
"BloomForCausalLM",
"ChatGLMModel",
"ChatGLMForConditionalGeneration",
"LlamaGPTQForCausalLM",
"BloomGPTQForCausalLM",
]
@ -213,11 +215,14 @@ class TPInferEngine:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
model = model.model if self.shard_config.inference_gptq else model
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
if self.shard_config.inference_gptq:
self._post_init_gptq_buffer(model)
self._post_init_gptq_buffer(self.model)
self.model = self.model.cuda()

View File

@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
@autotune(
configs=[
triton.Config(

View File

@ -13,10 +13,10 @@ except ImportError:
if HAS_TRITON:
"""
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
this functions are modified from https://github.com/ModelTC/lightllm
"""
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@triton.jit
def _context_flash_attention_kernel(
Q,
@ -145,20 +145,16 @@ if HAS_TRITON:
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def smooth_llama_context_attn_fwd(
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
BLOCK_N = 128
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
@ -203,6 +199,7 @@ if HAS_TRITON:
)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_kernel(
Q,
@ -264,6 +261,7 @@ if HAS_TRITON:
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_alibi_kernel(
Q,
@ -413,6 +411,7 @@ if HAS_TRITON:
)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@triton.jit
def _token_attn_softmax_fwd(
softmax_logics,
@ -479,6 +478,7 @@ if HAS_TRITON:
)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_2_kernel(
Prob,

View File

@ -8,7 +8,6 @@ from transformers import LlamaTokenizer
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@ -50,8 +49,6 @@ def run_llama_test(args):
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)
init_to_get_rotary(model.model.model, base=10000)
model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True