diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 279b54065..4662368b1 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,5 +1,3 @@ -import _utils - from .bloom import BloomInferenceForwards from .chatglm2 import ChatGLM2InferenceForwards from .llama import LlamaInferenceForwards diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py index cee418707..e476c3132 100644 --- a/colossalai/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -1,10 +1,67 @@ """ Utils for model inference """ +import os + +import torch + from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - return + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None)) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 64d6e947e..a7661cee1 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,12 +5,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import ( - copy_kv_cache_to_dest, - llama_context_attn_fwd, - rotary_embedding_fwd, - token_attention_fwd, -) +from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd + +from ._utils import copy_kv_to_mem_cache try: from vllm import layernorm_ops, pos_encoding_ops @@ -46,12 +43,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed -def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - return - - class LlamaInferenceForwards: """ This class holds forwards for llama inference. @@ -285,11 +276,6 @@ class LlamaInferenceForwards: rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - return - query_states = query_states.reshape(-1, self.num_heads, self.head_dim) key_states = key_states.reshape(-1, self.num_heads, self.head_dim) value_states = value_states.reshape(-1, self.num_heads, self.head_dim) @@ -298,7 +284,7 @@ class LlamaInferenceForwards: # first token generation # copy key and value calculated in current step to memory manager - _copy_kv_to_mem_cache( + copy_kv_to_mem_cache( infer_state.decode_layer_id, key_states, value_states, @@ -331,7 +317,7 @@ class LlamaInferenceForwards: else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - _copy_kv_to_mem_cache( + copy_kv_to_mem_cache( infer_state.decode_layer_id, key_states, value_states, diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py index cb223370a..90f8b4fd2 100644 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -1,7 +1,5 @@ from functools import partial -import torch - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, ChatGLMModel, @@ -9,13 +7,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( GLMTransformer, SelfAttention, ) + # import colossalai from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy -from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary +from ..modeling._utils import init_to_get_rotary +from ..modeling.chatglm2 import ChatGLM2InferenceForwards try: - from colossalai.kernel.triton.rms_norm import rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -23,7 +22,6 @@ except: class ChatGLM2InferPolicy(ChatGLMModelPolicy): - def __init__(self) -> None: super().__init__() @@ -32,45 +30,44 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy): self.shard_config._infer() model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward - method_replacement = {'forward': model_infer_forward} + method_replacement = {"forward": model_infer_forward} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward - method_replacement = {'forward': encoder_infer_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=GLMTransformer) + method_replacement = {"forward": encoder_infer_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=GLMTransformer + ) encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward - method_replacement = {'forward': encoder_layer_infer_forward} + method_replacement = {"forward": encoder_layer_infer_forward} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward - method_replacement = {'forward': attn_infer_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=SelfAttention) + method_replacement = {"forward": attn_infer_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=SelfAttention + ) # for rmsnorm and others, we need to check the shape return policy def postprocess(self): - _init_to_get_rotary(self.model) + init_to_get_rotary(self.model) return self.model class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward - method_replacement = {'forward': partial(model_infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=ChatGLMForConditionalGeneration) + method_replacement = {"forward": partial(model_infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration + ) return policy def postprocess(self): diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index eaaadadd1..507c1203d 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -3,11 +3,12 @@ from functools import partial import torch from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm -from colossalai.shardformer.layer import VocabParallelEmbedding1D -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy +from ..modeling._utils import init_to_get_rotary from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: @@ -50,38 +51,38 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}, + kwargs={"split_num": 1}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}, + kwargs={"split_num": 1}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}, + kwargs={"split_num": 1}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}, + kwargs={"split_num": 1}, ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}, + kwargs={"split_num": 1}, ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}, + kwargs={"split_num": 1}, ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}, - ) + kwargs={"split_num": 1}, + ), ], ) @@ -117,3 +118,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): ) return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 87ea9cf65..983069158 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -3,6 +3,12 @@ try: HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("Triton is not installed. Please install Triton to use Triton kernels.") + +# There may exist import error even if we have triton installed. +if HAS_TRITON: from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm @@ -23,7 +29,3 @@ try: "token_attention_fwd", "gptq_fused_linear_triton", ] - -except ImportError: - HAS_TRITON = False - print("Triton is not installed. Please install Triton to use Triton kernels.") diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 6e49fa80c..9614bdf88 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -15,30 +15,6 @@ from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_us os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -def init_to_get_rotary(self, base=10000): - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config, "max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config, "max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) - ) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() - return - - def print_perf_stats(latency_set, config, bs, warmup=3): # trim warmup queries latency_set = list(latency_set) @@ -66,7 +42,6 @@ def run_llama_test(args): tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) - init_to_get_rotary(model.model, base=10000) model = model.half() model_config = model.config diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 818ae0035..1bdee448c 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -1,47 +1,19 @@ import argparse -import logging import os import time import torch -from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig -from auto_gptq.nn_modules.qlinear import GeneralQuantLinear -from torch import distributed as dist -from torch.profiler import ProfilerActivity, profile, record_function -from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline +from auto_gptq import AutoGPTQForCausalLM +from transformers import LlamaTokenizer import colossalai -from colossalai.gptq import CaiQuantLinear -from colossalai.gptq.gptq_tp import replace_autogptq_linear from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' - - -def init_to_get_rotary(self, base=10000): - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config, "max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config, "max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / - self.config.head_dim_)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() - return +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def print_perf_stats(latency_set, config, bs, warmup=3): @@ -74,23 +46,23 @@ def run_llama_test(args): tokenizer.pad_token_id = tokenizer.eos_token_id # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, - device=torch.cuda.current_device(), - inject_fused_attention=False) + model = AutoGPTQForCausalLM.from_quantized( + quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False + ) init_to_get_rotary(model.model.model, base=10000) model_config = model.config - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, - inference_only=True, - inference_gptq=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 @@ -111,7 +83,7 @@ def run_llama_test(args): def check_llama(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test(args) @@ -123,12 +95,12 @@ def test_llama(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 0e5efe685..b260c7011 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -20,30 +20,6 @@ MAX_OUTPUT_LEN = 100 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -def init_to_get_rotary(self, base=10000): - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config, "max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config, "max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) - ) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() - return - - @parameterize( "test_config", [ @@ -56,7 +32,6 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm") for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() - init_to_get_rotary(orig_model.model, base=10000) orig_model = orig_model.half() data = data_gen_fn()