mirror of https://github.com/hpcaitech/ColossalAI
[inference]fix import bug and delete down useless init (#4830)
* fix import bug and release useless init * fix * fix * fixpull/4838/head^2
parent
573f270537
commit
013a4bedf0
|
@ -1,5 +1,3 @@
|
||||||
import _utils
|
|
||||||
|
|
||||||
from .bloom import BloomInferenceForwards
|
from .bloom import BloomInferenceForwards
|
||||||
from .chatglm2 import ChatGLM2InferenceForwards
|
from .chatglm2 import ChatGLM2InferenceForwards
|
||||||
from .llama import LlamaInferenceForwards
|
from .llama import LlamaInferenceForwards
|
||||||
|
|
|
@ -1,10 +1,67 @@
|
||||||
"""
|
"""
|
||||||
Utils for model inference
|
Utils for model inference
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
|
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
|
||||||
|
|
||||||
|
|
||||||
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
||||||
|
"""
|
||||||
|
This function copies the key and value cache to the memory cache
|
||||||
|
Args:
|
||||||
|
layer_id : id of current layer
|
||||||
|
key_buffer : key cache
|
||||||
|
value_buffer : value cache
|
||||||
|
context_mem_index : index of memory cache in kv cache manager
|
||||||
|
mem_manager : cache manager
|
||||||
|
"""
|
||||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
||||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
||||||
return
|
|
||||||
|
|
||||||
|
def init_to_get_rotary(self, base=10000, use_elem=False):
|
||||||
|
"""
|
||||||
|
This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer
|
||||||
|
Args:
|
||||||
|
self : Model that holds the rotary positional embedding
|
||||||
|
base : calculation arg
|
||||||
|
use_elem : activated when using chatglm-based models
|
||||||
|
"""
|
||||||
|
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
if not hasattr(self.config, "rope_scaling"):
|
||||||
|
rope_scaling_factor = 1.0
|
||||||
|
else:
|
||||||
|
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
||||||
|
|
||||||
|
if hasattr(self.config, "max_sequence_length"):
|
||||||
|
max_seq_len = self.config.max_sequence_length
|
||||||
|
elif hasattr(self.config, "max_position_embeddings"):
|
||||||
|
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
||||||
|
else:
|
||||||
|
max_seq_len = 2048 * rope_scaling_factor
|
||||||
|
base = float(base)
|
||||||
|
|
||||||
|
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
|
||||||
|
|
||||||
|
if ntk_alpha is not None:
|
||||||
|
ntk_alpha = float(ntk_alpha)
|
||||||
|
assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
|
||||||
|
if ntk_alpha > 1:
|
||||||
|
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
|
||||||
|
max_seq_len *= ntk_alpha
|
||||||
|
base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula
|
||||||
|
|
||||||
|
n_elem = self.config.head_dim_
|
||||||
|
if use_elem:
|
||||||
|
n_elem //= 2
|
||||||
|
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
|
||||||
|
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||||
|
freqs = torch.outer(t, inv_freq)
|
||||||
|
|
||||||
|
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
||||||
|
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
||||||
|
|
|
@ -5,12 +5,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
|
||||||
|
|
||||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||||
from colossalai.kernel.triton import (
|
from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
|
||||||
copy_kv_cache_to_dest,
|
|
||||||
llama_context_attn_fwd,
|
from ._utils import copy_kv_to_mem_cache
|
||||||
rotary_embedding_fwd,
|
|
||||||
token_attention_fwd,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import layernorm_ops, pos_encoding_ops
|
from vllm import layernorm_ops, pos_encoding_ops
|
||||||
|
@ -46,12 +43,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
|
||||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
|
||||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaInferenceForwards:
|
class LlamaInferenceForwards:
|
||||||
"""
|
"""
|
||||||
This class holds forwards for llama inference.
|
This class holds forwards for llama inference.
|
||||||
|
@ -285,11 +276,6 @@ class LlamaInferenceForwards:
|
||||||
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||||
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||||
|
|
||||||
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
|
||||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
|
||||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
|
||||||
return
|
|
||||||
|
|
||||||
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
||||||
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
|
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
|
||||||
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
|
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
|
||||||
|
@ -298,7 +284,7 @@ class LlamaInferenceForwards:
|
||||||
# first token generation
|
# first token generation
|
||||||
|
|
||||||
# copy key and value calculated in current step to memory manager
|
# copy key and value calculated in current step to memory manager
|
||||||
_copy_kv_to_mem_cache(
|
copy_kv_to_mem_cache(
|
||||||
infer_state.decode_layer_id,
|
infer_state.decode_layer_id,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
@ -331,7 +317,7 @@ class LlamaInferenceForwards:
|
||||||
else:
|
else:
|
||||||
# if decode is not contiguous, use triton kernel to copy key and value cache
|
# if decode is not contiguous, use triton kernel to copy key and value cache
|
||||||
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
|
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
|
||||||
_copy_kv_to_mem_cache(
|
copy_kv_to_mem_cache(
|
||||||
infer_state.decode_layer_id,
|
infer_state.decode_layer_id,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||||
ChatGLMForConditionalGeneration,
|
ChatGLMForConditionalGeneration,
|
||||||
ChatGLMModel,
|
ChatGLMModel,
|
||||||
|
@ -9,13 +7,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||||
GLMTransformer,
|
GLMTransformer,
|
||||||
SelfAttention,
|
SelfAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
# import colossalai
|
# import colossalai
|
||||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
|
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
|
||||||
|
|
||||||
from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
|
from ..modeling._utils import init_to_get_rotary
|
||||||
|
from ..modeling.chatglm2 import ChatGLM2InferenceForwards
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from colossalai.kernel.triton.rms_norm import rmsnorm_forward
|
|
||||||
HAS_TRITON_RMSNORM = True
|
HAS_TRITON_RMSNORM = True
|
||||||
except:
|
except:
|
||||||
print("you should install triton from https://github.com/openai/triton")
|
print("you should install triton from https://github.com/openai/triton")
|
||||||
|
@ -23,7 +22,6 @@ except:
|
||||||
|
|
||||||
|
|
||||||
class ChatGLM2InferPolicy(ChatGLMModelPolicy):
|
class ChatGLM2InferPolicy(ChatGLMModelPolicy):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -32,45 +30,44 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
|
||||||
self.shard_config._infer()
|
self.shard_config._infer()
|
||||||
|
|
||||||
model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
|
model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
|
||||||
method_replacement = {'forward': model_infer_forward}
|
method_replacement = {"forward": model_infer_forward}
|
||||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
|
||||||
|
|
||||||
encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
|
encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
|
||||||
method_replacement = {'forward': encoder_infer_forward}
|
method_replacement = {"forward": encoder_infer_forward}
|
||||||
self.append_or_create_method_replacement(description=method_replacement,
|
self.append_or_create_method_replacement(
|
||||||
policy=policy,
|
description=method_replacement, policy=policy, target_key=GLMTransformer
|
||||||
target_key=GLMTransformer)
|
)
|
||||||
|
|
||||||
encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
|
encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
|
||||||
method_replacement = {'forward': encoder_layer_infer_forward}
|
method_replacement = {"forward": encoder_layer_infer_forward}
|
||||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
|
||||||
|
|
||||||
attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
|
attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
|
||||||
method_replacement = {'forward': attn_infer_forward}
|
method_replacement = {"forward": attn_infer_forward}
|
||||||
self.append_or_create_method_replacement(description=method_replacement,
|
self.append_or_create_method_replacement(
|
||||||
policy=policy,
|
description=method_replacement, policy=policy, target_key=SelfAttention
|
||||||
target_key=SelfAttention)
|
)
|
||||||
|
|
||||||
# for rmsnorm and others, we need to check the shape
|
# for rmsnorm and others, we need to check the shape
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
_init_to_get_rotary(self.model)
|
init_to_get_rotary(self.model)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
|
class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
|
model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
|
||||||
method_replacement = {'forward': partial(model_infer_forward)}
|
method_replacement = {"forward": partial(model_infer_forward)}
|
||||||
self.append_or_create_method_replacement(description=method_replacement,
|
self.append_or_create_method_replacement(
|
||||||
policy=policy,
|
description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
|
||||||
target_key=ChatGLMForConditionalGeneration)
|
)
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
|
|
|
@ -3,11 +3,12 @@ from functools import partial
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
|
||||||
|
|
||||||
from colossalai.shardformer.layer import VocabParallelEmbedding1D
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
||||||
# import colossalai
|
# import colossalai
|
||||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||||
|
|
||||||
|
from ..modeling._utils import init_to_get_rotary
|
||||||
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
|
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -50,38 +51,38 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.q_proj",
|
suffix="self_attn.q_proj",
|
||||||
target_module=ColCaiQuantLinear,
|
target_module=ColCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.k_proj",
|
suffix="self_attn.k_proj",
|
||||||
target_module=ColCaiQuantLinear,
|
target_module=ColCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.v_proj",
|
suffix="self_attn.v_proj",
|
||||||
target_module=ColCaiQuantLinear,
|
target_module=ColCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.o_proj",
|
suffix="self_attn.o_proj",
|
||||||
target_module=RowCaiQuantLinear,
|
target_module=RowCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.gate_proj",
|
suffix="mlp.gate_proj",
|
||||||
target_module=ColCaiQuantLinear,
|
target_module=ColCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.up_proj",
|
suffix="mlp.up_proj",
|
||||||
target_module=ColCaiQuantLinear,
|
target_module=ColCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.down_proj",
|
suffix="mlp.down_proj",
|
||||||
target_module=RowCaiQuantLinear,
|
target_module=RowCaiQuantLinear,
|
||||||
kwargs={'split_num': 1},
|
kwargs={"split_num": 1},
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -117,3 +118,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
init_to_get_rotary(self.model.model)
|
||||||
|
return self.model
|
||||||
|
|
|
@ -3,6 +3,12 @@ try:
|
||||||
|
|
||||||
HAS_TRITON = True
|
HAS_TRITON = True
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
HAS_TRITON = False
|
||||||
|
print("Triton is not installed. Please install Triton to use Triton kernels.")
|
||||||
|
|
||||||
|
# There may exist import error even if we have triton installed.
|
||||||
|
if HAS_TRITON:
|
||||||
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
|
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
|
||||||
from .copy_kv_cache_dest import copy_kv_cache_to_dest
|
from .copy_kv_cache_dest import copy_kv_cache_to_dest
|
||||||
from .fused_layernorm import layer_norm
|
from .fused_layernorm import layer_norm
|
||||||
|
@ -23,7 +29,3 @@ try:
|
||||||
"token_attention_fwd",
|
"token_attention_fwd",
|
||||||
"gptq_fused_linear_triton",
|
"gptq_fused_linear_triton",
|
||||||
]
|
]
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
HAS_TRITON = False
|
|
||||||
print("Triton is not installed. Please install Triton to use Triton kernels.")
|
|
||||||
|
|
|
@ -15,30 +15,6 @@ from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_us
|
||||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def init_to_get_rotary(self, base=10000):
|
|
||||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
|
||||||
if not hasattr(self.config, "rope_scaling"):
|
|
||||||
rope_scaling_factor = 1.0
|
|
||||||
else:
|
|
||||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
|
||||||
if hasattr(self.config, "max_sequence_length"):
|
|
||||||
max_seq_len = self.config.max_sequence_length
|
|
||||||
elif hasattr(self.config, "max_position_embeddings"):
|
|
||||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
|
||||||
else:
|
|
||||||
max_seq_len = 2048 * rope_scaling_factor
|
|
||||||
base = float(base)
|
|
||||||
inv_freq = 1.0 / (
|
|
||||||
base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
|
|
||||||
)
|
|
||||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
|
||||||
freqs = torch.outer(t, inv_freq)
|
|
||||||
|
|
||||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
|
||||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||||
# trim warmup queries
|
# trim warmup queries
|
||||||
latency_set = list(latency_set)
|
latency_set = list(latency_set)
|
||||||
|
@ -66,7 +42,6 @@ def run_llama_test(args):
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
|
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
|
||||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
|
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
|
||||||
init_to_get_rotary(model.model, base=10000)
|
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
|
|
|
@ -1,47 +1,19 @@
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
from auto_gptq import AutoGPTQForCausalLM
|
||||||
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
|
from transformers import LlamaTokenizer
|
||||||
from torch import distributed as dist
|
|
||||||
from torch.profiler import ProfilerActivity, profile, record_function
|
|
||||||
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline
|
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.gptq import CaiQuantLinear
|
|
||||||
from colossalai.gptq.gptq_tp import replace_autogptq_linear
|
|
||||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||||
|
from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.shardformer import ShardConfig
|
from colossalai.shardformer import ShardConfig
|
||||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def init_to_get_rotary(self, base=10000):
|
|
||||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
|
||||||
if not hasattr(self.config, "rope_scaling"):
|
|
||||||
rope_scaling_factor = 1.0
|
|
||||||
else:
|
|
||||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
|
||||||
if hasattr(self.config, "max_sequence_length"):
|
|
||||||
max_seq_len = self.config.max_sequence_length
|
|
||||||
elif hasattr(self.config, "max_position_embeddings"):
|
|
||||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
|
||||||
else:
|
|
||||||
max_seq_len = 2048 * rope_scaling_factor
|
|
||||||
base = float(base)
|
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
|
|
||||||
self.config.head_dim_))
|
|
||||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
|
||||||
freqs = torch.outer(t, inv_freq)
|
|
||||||
|
|
||||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
|
||||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||||
|
@ -74,23 +46,23 @@ def run_llama_test(args):
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
# load quantized model to the first GPU
|
# load quantized model to the first GPU
|
||||||
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
|
model = AutoGPTQForCausalLM.from_quantized(
|
||||||
device=torch.cuda.current_device(),
|
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
|
||||||
inject_fused_attention=False)
|
)
|
||||||
|
|
||||||
init_to_get_rotary(model.model.model, base=10000)
|
init_to_get_rotary(model.model.model, base=10000)
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
shard_config = ShardConfig(
|
||||||
inference_only=True,
|
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
|
||||||
inference_gptq=True)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
|
|
||||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||||
|
|
||||||
input_tokens = {
|
input_tokens = {
|
||||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
|
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
|
||||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
|
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
|
||||||
}
|
}
|
||||||
|
|
||||||
iters = 10
|
iters = 10
|
||||||
|
@ -111,7 +83,7 @@ def run_llama_test(args):
|
||||||
|
|
||||||
def check_llama(rank, world_size, port, args):
|
def check_llama(rank, world_size, port, args):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
run_llama_test(args)
|
run_llama_test(args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -123,12 +95,12 @@ def test_llama(args):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
|
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
|
||||||
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
|
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
|
||||||
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
|
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||||
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
|
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
|
||||||
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
|
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
|
||||||
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
|
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
@ -20,30 +20,6 @@ MAX_OUTPUT_LEN = 100
|
||||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||||
|
|
||||||
|
|
||||||
def init_to_get_rotary(self, base=10000):
|
|
||||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
|
||||||
if not hasattr(self.config, "rope_scaling"):
|
|
||||||
rope_scaling_factor = 1.0
|
|
||||||
else:
|
|
||||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
|
||||||
if hasattr(self.config, "max_sequence_length"):
|
|
||||||
max_seq_len = self.config.max_sequence_length
|
|
||||||
elif hasattr(self.config, "max_position_embeddings"):
|
|
||||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
|
||||||
else:
|
|
||||||
max_seq_len = 2048 * rope_scaling_factor
|
|
||||||
base = float(base)
|
|
||||||
inv_freq = 1.0 / (
|
|
||||||
base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
|
|
||||||
)
|
|
||||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
|
||||||
freqs = torch.outer(t, inv_freq)
|
|
||||||
|
|
||||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
|
||||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize(
|
@parameterize(
|
||||||
"test_config",
|
"test_config",
|
||||||
[
|
[
|
||||||
|
@ -56,7 +32,6 @@ def run_llama_test(test_config):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
|
||||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||||
orig_model = model_fn()
|
orig_model = model_fn()
|
||||||
init_to_get_rotary(orig_model.model, base=10000)
|
|
||||||
orig_model = orig_model.half()
|
orig_model = orig_model.half()
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue