mirror of https://github.com/hpcaitech/ColossalAI
[refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smooth cuda
parent 48d0a58d10
commit 450115bd0f
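With this change both quantized LLaMA paths go through the same engine entry point: the quantized checkpoint is loaded by its own frontend and handed to CaiInferEngine together with quant="gptq" or quant="smoothquant". A minimal usage sketch of the GPTQ path, adapted from the test added in this commit — the checkpoint path, parallel sizes and token ids are illustrative, and the script is assumed to run on every rank of a tp_size * pp_size process group (e.g. launched with torchrun):

import torch
from auto_gptq import AutoGPTQForCausalLM

import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

# distributed init is assumed to have happened on all 4 ranks (tp_size=2 * pp_size=2)
colossalai.launch_from_torch(config={})

# load a GPTQ-quantized LLaMA checkpoint (placeholder path);
# with quant="gptq" the engine unwraps the inner transformers model (model.model) itself
model = AutoGPTQForCausalLM.from_quantized(
    "/path/to/llama-gptq", inject_fused_attention=False, device=torch.cuda.current_device()
)

engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
    quant="gptq",  # or "smoothquant" together with a SmoothLlamaForCausalLM checkpoint
)

inputs = {
    "input_ids": torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64).cuda(),
    "attention_mask": torch.ones((1, 6), dtype=torch.int64).cuda(),
}
# on rank 0, output[0] holds max_output_len generated tokens (see the tests added below)
output = engine.inference(inputs)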
@@ -14,10 +14,7 @@ from ..tensor_parallel.kvcache_manager import MemoryManager
 
 PP_AXIS, TP_AXIS = 0, 1
 
-_supported_models = [
-    "LlamaForCausalLM",
-    "BloomForCausalLM",
-]
+_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM"]
 
 
 class CaiInferEngine:
@@ -70,12 +67,21 @@ class CaiInferEngine:
         max_batch_size: int = 4,
         max_input_len: int = 32,
         max_output_len: int = 32,
+        quant: str = None,
         verbose: bool = False,
         # TODO: implement early_stopping, and various gerneration options
         early_stopping: bool = False,
         do_sample: bool = False,
         num_beams: int = 1,
     ) -> None:
+        if quant == "gptq":
+            from ..quant.gptq import GPTQManager
+
+            self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
+            model = model.model
+        elif quant == "smoothquant":
+            model = model.model
+
         assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
         assert (
             tp_size * pp_size == dist.get_world_size()
@@ -85,9 +91,14 @@ class CaiInferEngine:
 
         assert max_batch_size <= 64, "Max batch size exceeds the constraint"
         assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
+        assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
         self.pp_size = pp_size
         self.tp_size = tp_size
+        self.quant = quant
 
+        if quant == "smoothquant" and dtype != "fp32":
+            dtype = "fp32"
+            print("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32")
 
         if dtype == "fp16":
             self.dtype = torch.float16
@@ -118,6 +129,8 @@ class CaiInferEngine:
         self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
 
         self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
+        if quant == "gptq":
+            self.gptq_manager.post_init_gptq_buffer(self.model)
 
     def inference(self, input_list):
         """
@@ -149,6 +162,7 @@ class CaiInferEngine:
             enable_flash_attention=False,
             enable_jit_fused=False,
             enable_sequence_parallelism=False,
+            quant=self.quant,
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
@@ -158,7 +172,7 @@ class CaiInferEngine:
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
         if model.config.model_type == "llama":
             head_dim = model.config.hidden_size // model.config.num_attention_heads
-            head_num = model.config.num_attention_heads // self.tp_size
+            head_num = model.config.num_key_value_heads // self.tp_size
             num_hidden_layers = (
                 model.config.num_hidden_layers
                 if hasattr(model.config, "num_hidden_layers")
@@ -171,5 +185,8 @@ class CaiInferEngine:
             num_hidden_layers = model.config.n_layer
         layer_num = num_hidden_layers // self.pp_size
 
-        cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
+        if self.quant == "smoothquant":
+            cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
+        else:
+            cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
         return cache_manager
@@ -1,3 +1,4 @@
+from .bloom import BloomInferenceForwards
 from .llama import LlamaInferenceForwards
 
-__all__ = ["LlamaInferenceForwards"]
+__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"]
@@ -45,14 +45,15 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def module_policy(self):
         policy = super().module_policy()
-
-        if self.shard_config.inference_gptq:
+        decoder_attribute_replacement = {
+            "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            // self.shard_config.tensor_parallel_size,
+        }
+        if self.shard_config.quant == "gptq":
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
 
-            decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            }
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
@@ -94,6 +95,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 ],
             )
 
+        elif self.shard_config.quant == "smoothquant":
+            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
+            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
+                ColW8A8BFP32OFP32Linear,
+                RowW8A8B8O8Linear,
+                RowW8A8BFP32O32LinearSiLU,
+                RowW8A8BFP32OFP32Linear,
+            )
+
+            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=ColW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=RowW8A8BFP32O32LinearSiLU,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=RowW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=ColW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
         self.shard_config._infer()
 
         infer_forward = LlamaInferenceForwards.llama_model_forward
@@ -0,0 +1 @@
+from .smoothquant.models.llama import SmoothLlamaForCausalLM
@@ -2,3 +2,4 @@ from .cai_gptq import HAS_AUTO_GPTQ
 
 if HAS_AUTO_GPTQ:
     from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
+from .gptq_manager import GPTQManager
@@ -0,0 +1,61 @@
+import torch
+
+
+class GPTQManager:
+    def __init__(self, quant_config, max_input_len: int = 1):
+        self.max_dq_buffer_size = 1
+        self.max_inner_outer_dim = 1
+        self.bits = quant_config.bits
+        self.use_act_order = quant_config.desc_act
+        self.max_input_len = 1
+        self.gptq_temp_state_buffer = None
+        self.gptq_temp_dq_buffer = None
+        self.quant_config = quant_config
+
+    def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
+        from .cai_gptq import CaiQuantLinear
+
+        HAS_GPTQ_CUDA = False
+        try:
+            from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
+            gptq_cuda = GPTQBuilder().load()
+            HAS_GPTQ_CUDA = True
+        except ImportError:
+            warnings.warn("CUDA gptq is not installed")
+            HAS_GPTQ_CUDA = False
+
+        for name, submodule in model.named_modules():
+            if isinstance(submodule, CaiQuantLinear):
+                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
+
+                if self.use_act_order:
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
+                self.bits = submodule.bits
+        if not (HAS_GPTQ_CUDA and self.bits == 4):
+            return
+
+        max_input_len = 1
+        if self.use_act_order:
+            max_input_len = self.max_input_len
+        # The temp_state buffer is required to reorder X in the act-order case.
+        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )
+        # Using the default from exllama repo here.
+        matmul_recons_thd = 8
+        matmul_fused_remap = False
+        matmul_no_half2 = False
+        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+        torch.cuda.empty_cache()
@@ -4,9 +4,7 @@ try:
     HAS_TORCH_INT = True
 except ImportError:
     HAS_TORCH_INT = False
-    raise ImportError(
-        "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
-    )
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
 
 if HAS_TORCH_INT:
     from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
@@ -9,7 +9,6 @@ from functools import partial
 from os.path import isdir, isfile, join
 from typing import Dict, List, Optional, Union
 
-import accelerate
 import numpy as np
 import torch
 import torch.nn as nn
@@ -24,6 +23,15 @@ from transformers.utils.hub import PushToHubMixin, cached_file
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
 
+try:
+    import accelerate
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    print("accelerate is not installed.")
+
+
 SUPPORTED_MODELS = ["llama"]
 
 
@@ -1,17 +1,25 @@
 # modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 
 import torch
-from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
-from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+try:
+    from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
+    from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
 
 try:
     from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
 
     smoothquant_cuda = SmoothquantBuilder().load()
     HAS_SMOOTHQUANT_CUDA = True
-except ImportError:
+except:
     HAS_SMOOTHQUANT_CUDA = False
-    raise ImportError("CUDA smoothquant linear is not installed")
+    print("CUDA smoothquant linear is not installed")
 
 
 class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ class W8A8BFP32OFP32Linear(torch.nn.Module):
         )
         self.register_buffer(
             "bias",
-            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
+            torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
         )
         self.register_buffer("a", torch.tensor(alpha))
 
     def _apply(self, fn):
         # prevent the bias from being converted to half
         super()._apply(fn)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(torch.float32)
         return self
 
     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
         self.weight = self.weight.to(*args, **kwargs)
-        self.bias = self.bias.to(*args, **kwargs)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(*args, **kwargs)
+            self.bias = self.bias.to(torch.float32)
         return self
 
     @torch.no_grad()
@@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -18,7 +17,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaMLP,
     LlamaRotaryEmbedding,
-    repeat_kv,
     rotate_half,
 )
 from transformers.utils import add_start_docstrings_to_model_forward
@@ -31,10 +29,31 @@ from colossalai.kernel.triton import (
     smooth_token_attention_fwd,
 )
 
+try:
+    from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
+
 from .base_model import BaseSmoothForCausalLM
 from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class LLamaSmoothquantAttention(nn.Module):
     def __init__(
         self,
@@ -116,7 +135,6 @@ class LLamaSmoothquantAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -131,8 +149,7 @@ class LLamaSmoothquantAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        cos = rotary_emb[0]
-        sin = rotary_emb[1]
+        cos, sin = infer_state.position_cos, infer_state.position_sin
 
         int8_rotary_embedding_fwd(
             query_states.view(-1, self.num_heads, self.head_dim),
@@ -348,7 +365,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
     def forward(
         self,
        hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -378,7 +394,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
-            rotary_emb=rotary_emb,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -650,15 +665,15 @@ def llama_model_forward(
         raise NotImplementedError("not implement gradient_checkpointing and training options ")
 
     if past_key_values_length == 0:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
     else:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
@@ -673,7 +688,6 @@ def llama_model_forward(
 
         layer_outputs = decoder_layer(
             hidden_states,
-            rotary_emb=(position_cos, position_sin),
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -0,0 +1,264 @@
+from typing import List, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import ParallelModule
+
+from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
+
+
+def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
+    qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
+    if smooth_linear.bias is not None:
+        bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)
+
+    smooth_split_out_features = para_linear.out_features // split_num
+
+    for i in range(split_num):
+        para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
+            tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
+        ]
+
+        if para_linear.bias is not None:
+            para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
+                :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
+            ]
+
+
+def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
+    qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)
+
+    smooth_split_in_features = para_linear.in_features // split_num
+
+    for i in range(split_num):
+        para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
+            :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
+        ]
+
+    if smooth_linear.bias is not None:
+        para_linear.bias.copy_(smooth_linear.bias)
+
+
+class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        out_features = module.out_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+        linear_1d.b = module.b.clone().detach()
+        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        return linear_1d
+
+
+class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if in_features < tp_size:
+            return module
+
+        if in_features % tp_size != 0:
+            raise ValueError(
+                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = torch.tensor(module.a)
+        linear_1d.b = torch.tensor(module.b)
+
+        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        if linear_1d.bias is not None:
+            linear_1d.bias = linear_1d.bias // tp_size
+
+        return linear_1d
+
+    @torch.no_grad()
+    def forward(self, x):
+        output = super().forward(x)
+        if self.tp_size > 1:
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+        return output
+
+
+class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        out_features = module.out_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+
+        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        return linear_1d
+
+
+class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        out_features = module.out_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+
+        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        return linear_1d
+
+
+class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if in_features < tp_size:
+            return module
+
+        if in_features % tp_size != 0:
+            raise ValueError(
+                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+
+        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        if linear_1d.bias is not None:
+            linear_1d.bias = linear_1d.bias / tp_size
+
+        return linear_1d
+
+    @torch.no_grad()
+    def forward(self, x):
+        output = super().forward(x)
+        if self.tp_size > 1:
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+        return output
@@ -37,6 +37,7 @@ class ShardConfig:
     inference_gptq: bool = False
     enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
+    quant: str = None
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -77,4 +78,3 @@ class ShardConfig:
         Set default params for inference.
         """
         # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
-        pass
@@ -0,0 +1,79 @@
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+from auto_gptq import AutoGPTQForCausalLM
+
+import colossalai
+from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def run_llama_test(args):
+    quantized_model_dir = args.quantized_path
+    max_batch_size = args.max_batch_size
+    max_input_len = args.max_input_len
+    max_output_len = args.max_output_len
+    micro_batch_size = args.micro_batch_size
+    # load quantized model to the first GPU
+    model = AutoGPTQForCausalLM.from_quantized(
+        quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device()
+    )
+
+    engine = CaiInferEngine(
+        tp_size=2,
+        pp_size=2,
+        model=model,
+        model_policy=LlamaModelInferPolicy(),
+        max_batch_size=max_batch_size,
+        max_input_len=max_input_len,
+        max_output_len=max_output_len,
+        micro_batch_size=micro_batch_size,
+        quant="gptq",
+    )
+
+    def data_gen():
+        input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+        attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    inputs = data_gen()
+    for k, v in inputs.items():
+        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+            new_shape = [1] * v.dim()
+            new_shape[0] = 16
+            inputs[k] = v.to("cuda").repeat(*new_shape)
+
+    output = engine.inference(inputs)
+    if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
+
+
+def check_llama(rank, world_size, port, args):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gptq_llama(args):
+    spawn(check_llama, args.tp_size * args.pp_size, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+    parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
+    parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
+    parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
+    parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
+    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
+    args = parser.parse_args()
+
+    test_gptq_llama(args)
@@ -0,0 +1,76 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
+from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+
+@torch.no_grad()
+def run_llama_test(args):
+    quantized_model_dir = args.quantized_path
+    max_batch_size = args.max_batch_size
+    max_input_len = args.max_input_len
+    max_output_len = args.max_output_len
+    micro_batch_size = args.micro_batch_size
+
+    def data_gen():
+        input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+        attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    inputs = data_gen()
+    for k, v in inputs.items():
+        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+            new_shape = [1] * v.dim()
+            new_shape[0] = 16
+            inputs[k] = v.to("cuda").repeat(*new_shape)
+
+    model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b")
+    model = model.cuda()
+
+    engine = CaiInferEngine(
+        tp_size=2,
+        pp_size=2,
+        model=model,
+        model_policy=LlamaModelInferPolicy(),
+        max_batch_size=max_batch_size,
+        max_input_len=max_input_len,
+        max_output_len=max_output_len,
+        micro_batch_size=micro_batch_size,
+        quant="smoothquant",
+    )
+
+    output = engine.inference(inputs)
+    if dist.get_rank() == 0:
+        assert len(output[0]) == 32, f"{len(output)}, {32}"
+
+
+def check_llama(rank, world_size, port, args):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_smoothquant_llama():
+    spawn(check_llama, args.tp_size * args.pp_size, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+    parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
+    parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
+    parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
+    parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
+    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
+
+    args = parser.parse_args()
+    test_smoothquant_llama()
@@ -9,6 +9,10 @@ from colossalai.inference import BloomModelInferPolicy, CaiInferEngine
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+try:
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
 
 
 def data_gen():
@@ -88,7 +92,10 @@ def check_tp_inference(rank, world_size, port):
     run_tp_inference_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -9,6 +9,10 @@ from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+try:
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
 
 
 def data_gen():
@@ -90,7 +94,10 @@ def check_tp_inference(rank, world_size, port):
     run_tp_inference_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()