diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py
index 180e6c6e8..6a1d96ece 100644
--- a/colossalai/inference/quant/smoothquant/models/base_model.py
+++ b/colossalai/inference/quant/smoothquant/models/base_model.py
@@ -132,6 +132,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
             mean_scale = np.mean([v["input"] for v in act_dict.values()])
             pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
     def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
         model.eval()
         device = next(model.parameters()).device
@@ -163,6 +164,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):

         return act_scales

+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
     @torch.no_grad()
     def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
         if not isinstance(fcs, list):
@@ -189,6 +191,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
     def create_quantized_model(model):
         raise NotImplementedError("Not implement create_quantized_model method")

+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     def save_quantized(
         self,
         save_dir: str,
@@ -249,6 +252,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):

         self.model.config.save_pretrained(save_dir)

+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     def save_pretrained(
         self,
         save_dir: str,
@@ -260,6 +264,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
         warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
         self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     @classmethod
     def from_pretrained(
         cls,
@@ -354,6 +359,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):

         return cls(model, False)

+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     @classmethod
     def from_quantized(
         cls,
diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py
index 048565bfb..969c390a0 100644
--- a/colossalai/inference/quant/smoothquant/models/linear.py
+++ b/colossalai/inference/quant/smoothquant/models/linear.py
@@ -62,6 +62,7 @@ class W8A8BFP32O32LinearSiLU(torch.nn.Module):
         return int8_module


+# Modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 class W8A8B8O8Linear(torch.nn.Module):
     # For qkv_proj
     def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
@@ -117,6 +118,7 @@ class W8A8B8O8Linear(torch.nn.Module):
         return int8_module


+# Modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 class W8A8BFP32OFP32Linear(torch.nn.Module):
     # For fc2 and out_proj
     def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py
index 9c77feeb3..4c3d6dcc0 100644
--- a/colossalai/inference/quant/smoothquant/models/llama.py
+++ b/colossalai/inference/quant/smoothquant/models/llama.py
@@ -419,6 +419,7 @@ class LlamaApplyRotary(nn.Module):
         return x_embed


+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 def llama_decoder_layer_forward(
     self,
     hidden_states: torch.Tensor,
@@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False):
     return _cos_cached, _sin_cached


+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 def llama_model_forward(
     self,
@@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
     def __init__(self, model: PreTrainedModel, quantized: bool = False):
         super().__init__(model, quantized)

+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
     def get_act_dict(
         self,
         tokenizer,
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index e4c4a2d70..216b134f5 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -21,6 +21,8 @@ _supported_models = [
     "BloomForCausalLM",
     "ChatGLMModel",
     "ChatGLMForConditionalGeneration",
+    "LlamaGPTQForCausalLM",
+    "BloomGPTQForCausalLM",
 ]


@@ -213,11 +215,14 @@ class TPInferEngine:
         ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
         model_name = model.__class__.__name__
         assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+
+        model = model.model if self.shard_config.inference_gptq else model
+
         policy = get_autopolicy(model, inference_only=True)
         self.model, _ = shardformer.optimize(model, policy)

         if self.shard_config.inference_gptq:
-            self._post_init_gptq_buffer(model)
+            self._post_init_gptq_buffer(self.model)

         self.model = self.model.cuda()
diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py
index 8460103e2..2dc1fe044 100644
--- a/colossalai/kernel/triton/gptq_triton.py
+++ b/colossalai/kernel/triton/gptq_triton.py
@@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
 @autotune(
     configs=[
         triton.Config(
diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py
index ee0df6a74..071de58e2 100644
--- a/colossalai/kernel/triton/smooth_attention.py
+++ b/colossalai/kernel/triton/smooth_attention.py
@@ -13,10 +13,10 @@ except ImportError:

 if HAS_TRITON:
     """
-    this function is modified from
-    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
+    these functions are modified from https://github.com/ModelTC/lightllm
     """

+    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
     @triton.jit
     def _context_flash_attention_kernel(
         Q,
@@ -145,20 +145,16 @@ if HAS_TRITON:
         tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
         return

-
-
     @torch.no_grad()
     def smooth_llama_context_attn_fwd(
         q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
     ):
-
         BLOCK = 128
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk, "context process only supports equal query, key, value length"
         assert Lk == Lv, "context process only supports equal query, key, value length"
         assert Lk in {16, 32, 64, 128}
-        BLOCK_N = 128
         sm_scale = 1.0 / math.sqrt(Lk)
         batch, head = b_seq_len.shape[0], q.shape[1]
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
@@ -203,6 +199,7 @@ if HAS_TRITON:
         )
         return

+    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
     @triton.jit
     def _token_attn_1_kernel(
         Q,
@@ -264,6 +261,7 @@ if HAS_TRITON:
         tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
         return

+    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
     @triton.jit
     def _token_attn_1_alibi_kernel(
         Q,
@@ -413,6 +411,7 @@ if HAS_TRITON:
         )
         return

+    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
     @triton.jit
     def _token_attn_softmax_fwd(
         softmax_logics,
@@ -479,6 +478,7 @@ if HAS_TRITON:
         )
         return

+    # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
     @triton.jit
     def _token_attn_2_kernel(
         Prob,
diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py
index 1bdee448c..4823377d7 100644
--- a/examples/inference/gptq_llama.py
+++ b/examples/inference/gptq_llama.py
@@ -8,7 +8,6 @@ from transformers import LlamaTokenizer

 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@@ -50,8 +49,6 @@ def run_llama_test(args):
         quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
     )

-    init_to_get_rotary(model.model.model, base=10000)
-
     model_config = model.config
     shard_config = ShardConfig(
         enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
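
Note: below is a minimal, hypothetical sketch of how the GPTQ inference path touched by this patch might be driven, mirroring examples/inference/gptq_llama.py. The loader class, the TPInferEngine constructor arguments, and the generate() kwargs are assumptions based on the existing example script and are not introduced by this change.

# Hypothetical usage sketch (not part of this patch): load an AutoGPTQ-quantized
# Llama checkpoint and run it through TPInferEngine with inference_gptq=True.
# Assumes colossalai.launch() has already initialized the distributed environment,
# and that the engine/generate arguments below match the existing example.
import torch
from auto_gptq import AutoGPTQForCausalLM
from transformers import LlamaTokenizer

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

quantized_model_dir = "/path/to/llama-gptq"  # placeholder checkpoint directory

tokenizer = LlamaTokenizer.from_pretrained(quantized_model_dir)
model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)

# inference_gptq=True makes TPInferEngine unwrap the GPTQ wrapper (model.model)
# before sharding, which is the behavior added in engine.py above.
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True, inference_gptq=True)
engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=128, max_output_len=64)

inputs = tokenizer("Introduce some landmarks in Beijing", return_tensors="pt")
outputs = engine.generate(inputs, do_sample=False, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))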