[refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda
Xu Kai authored on 2023-11-08 09:17:52 +08:00, committed by FoolPlayer
parent 48d0a58d10
commit 450115bd0f
16 changed files with 635 additions and 41 deletions
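
In short, this commit routes quantized checkpoints through a new quant argument on CaiInferEngine. A minimal sketch of the GPTQ path, condensed from the test scripts added in this commit (the checkpoint path is a placeholder, and the script has to run under colossalai.launch with tp_size * pp_size ranks):

import torch
from auto_gptq import AutoGPTQForCausalLM

from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

# placeholder path to a GPTQ-quantized Llama checkpoint
model = AutoGPTQForCausalLM.from_quantized(
    "/path/to/llama-7b-gptq", inject_fused_attention=False, device=torch.cuda.current_device()
)
engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    quant="gptq",  # or "smoothquant" together with SmoothLlamaForCausalLM.from_quantized(...)
)
inputs = {
    "input_ids": torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64).cuda(),
    "attention_mask": torch.ones(1, 8, dtype=torch.int64).cuda(),
}
output = engine.inference(inputs)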


@ -14,10 +14,7 @@ from ..tensor_parallel.kvcache_manager import MemoryManager
PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
"BloomForCausalLM",
]
_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM"]
class CaiInferEngine:
@ -70,12 +67,21 @@ class CaiInferEngine:
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
quant: str = None,
verbose: bool = False,
# TODO: implement early_stopping and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
if quant == "gptq":
from ..quant.gptq import GPTQManager
self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
model = model.model
elif quant == "smoothquant":
model = model.model
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
@ -85,9 +91,14 @@ class CaiInferEngine:
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq', or None"
self.pp_size = pp_size
self.tp_size = tp_size
self.quant = quant
if quant == "smoothquant" and dtype != "fp32":
dtype = "fp32"
print("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32")
if dtype == "fp16":
self.dtype = torch.float16
@ -118,6 +129,8 @@ class CaiInferEngine:
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)
def inference(self, input_list):
"""
@ -149,6 +162,7 @@ class CaiInferEngine:
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
quant=self.quant,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@ -158,7 +172,7 @@ class CaiInferEngine:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
if model.config.model_type == "llama":
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads // self.tp_size
head_num = model.config.num_key_value_heads // self.tp_size
num_hidden_layers = (
model.config.num_hidden_layers
if hasattr(model.config, "num_hidden_layers")
@ -171,5 +185,8 @@ class CaiInferEngine:
num_hidden_layers = model.config.n_layer
layer_num = num_hidden_layers // self.pp_size
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
if self.quant == "smoothquant":
cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
else:
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
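For orientation, with the engine defaults above (max_batch_size=4, max_input_len=32, max_output_len=32) the cache is sized as

max_total_token_num = 4 * (32 + 32) = 256  # token slots per layer, allocated in int8 when quant == "smoothquant", otherwise in self.dtype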


@ -1,3 +1,4 @@
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards
__all__ = ["LlamaInferenceForwards"]
__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"]


@ -45,14 +45,15 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.inference_gptq:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.quant == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
@ -94,6 +95,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
],
)
elif self.shard_config.quant == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward


@ -0,0 +1 @@
from .smoothquant.models.llama import SmoothLlamaForCausalLM


@ -2,3 +2,4 @@ from .cai_gptq import HAS_AUTO_GPTQ
if HAS_AUTO_GPTQ:
from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
from .gptq_manager import GPTQManager


@ -0,0 +1,61 @@
import warnings
import torch
class GPTQManager:
def __init__(self, quant_config, max_input_len: int = 1):
self.max_dq_buffer_size = 1
self.max_inner_outer_dim = 1
self.bits = quant_config.bits
self.use_act_order = quant_config.desc_act
self.max_input_len = max_input_len
self.gptq_temp_state_buffer = None
self.gptq_temp_dq_buffer = None
self.quant_config = quant_config
def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
from .cai_gptq import CaiQuantLinear
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
if self.use_act_order:
self.max_inner_outer_dim = max(
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return
max_input_len = 1
if self.use_act_order:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros(
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
self.gptq_temp_dq_buffer = torch.zeros(
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
)
gptq_cuda.prepare_buffers(
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
torch.cuda.empty_cache()
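For context, the engine changes above use this manager in two steps (condensed from the CaiInferEngine hunks; not standalone code):

# 1) in CaiInferEngine.__init__, before sharding:
self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
model = model.model  # unwrap the auto_gptq wrapper to the plain LlamaForCausalLM
# 2) after ShardFormer has replaced the linears with CaiQuantLinear:
self.gptq_manager.post_init_gptq_buffer(self.model)  # allocate the exllama-style temp_state / temp_dq buffers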


@ -4,9 +4,7 @@ try:
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
raise ImportError(
"Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
)
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
if HAS_TORCH_INT:
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP


@ -9,7 +9,6 @@ from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union
import accelerate
import numpy as np
import torch
import torch.nn as nn
@ -24,6 +23,15 @@ from transformers.utils.hub import PushToHubMixin, cached_file
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
try:
import accelerate
HAS_ACCELERATE = True
except ImportError:
HAS_ACCELERATE = False
print("accelerate is not installed.")
SUPPORTED_MODELS = ["llama"]


@ -1,17 +1,25 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax
try:
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
except:
HAS_SMOOTHQUANT_CUDA = False
raise ImportError("CUDA smoothquant linear is not installed")
print("CUDA smoothquant linear is not installed")
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@ -138,21 +146,23 @@ class W8A8BFP32OFP32Linear(torch.nn.Module):
)
self.register_buffer(
"bias",
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
self.bias = self.bias.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(torch.float32)
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
return self
@torch.no_grad()


@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
@ -18,7 +17,6 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaMLP,
LlamaRotaryEmbedding,
repeat_kv,
rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward
@ -31,10 +29,31 @@ from colossalai.kernel.triton import (
smooth_token_attention_fwd,
)
try:
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
from .base_model import BaseSmoothForCausalLM
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
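As a sanity check, the docstring's equivalence can be verified with plain torch (standalone sketch, not part of this file):

import torch

x = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
expanded = x[:, :, None, :, :].expand(2, 4, 8, 16, 64).reshape(2, 32, 16, 64)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=8, dim=1))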
class LLamaSmoothquantAttention(nn.Module):
def __init__(
self,
@ -116,7 +135,6 @@ class LLamaSmoothquantAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@ -131,8 +149,7 @@ class LLamaSmoothquantAttention(nn.Module):
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
cos = rotary_emb[0]
sin = rotary_emb[1]
cos, sin = infer_state.position_cos, infer_state.position_sin
int8_rotary_embedding_fwd(
query_states.view(-1, self.num_heads, self.head_dim),
@ -348,7 +365,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@ -378,7 +394,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
rotary_emb=rotary_emb,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@ -650,15 +665,15 @@ def llama_model_forward(
raise NotImplementedError("gradient_checkpointing and training options are not implemented")
if past_key_values_length == 0:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@ -673,7 +688,6 @@ def llama_model_forward(
layer_outputs = decoder_layer(
hidden_states,
rotary_emb=(position_cos, position_sin),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,


@ -0,0 +1,264 @@
from typing import List, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
if smooth_linear.bias is not None:
bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)
smooth_split_out_features = para_linear.out_features // split_num
for i in range(split_num):
para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
]
if para_linear.bias is not None:
para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
:, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
]
def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)
smooth_split_in_features = para_linear.in_features // split_num
for i in range(split_num):
para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
:, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
]
if smooth_linear.bias is not None:
para_linear.bias.copy_(smooth_linear.bias)
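Conceptually, the row split hands each tensor-parallel rank a slice of the output dimension, while the column split hands it a slice of the input dimension (the partial outputs are then summed by the all-reduce in the Col* forwards below). A plain-torch sketch of the two slicings, assuming tp_size=2 and split_num=1:

import torch

tp_size, tp_rank = 2, 0
w = torch.arange(8 * 4).reshape(8, 4)         # full weight: (out_features=8, in_features=4)
row_shard = w.chunk(tp_size, dim=0)[tp_rank]  # row-parallel slice: (4, 4), split along out_features
col_shard = w.chunk(tp_size, dim=1)[tp_rank]  # column-parallel slice: (8, 2), split along in_features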
class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
out_features = module.out_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if out_features < tp_size:
return module
if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()
linear_1d.b = module.b.clone().detach()
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d
class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if in_features < tp_size:
return module
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()
linear_1d.b = module.b.clone().detach()
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
if linear_1d.bias is not None:
linear_1d.bias = linear_1d.bias // tp_size
return linear_1d
@torch.no_grad()
def forward(self, x):
output = super().forward(x)
if self.tp_size > 1:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
return output
class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
out_features = module.out_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if out_features < tp_size:
return module
if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d
class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
out_features = module.out_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if out_features < tp_size:
return module
if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d
class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if in_features < tp_size:
return module
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
if linear_1d.bias is not None:
linear_1d.bias = linear_1d.bias / tp_size
return linear_1d
@torch.no_grad()
def forward(self, x):
output = super().forward(x)
if self.tp_size > 1:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
return output


@ -37,6 +37,7 @@ class ShardConfig:
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
quant: str = None
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@ -77,4 +78,3 @@ class ShardConfig:
Set default params for inference.
"""
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass


@ -0,0 +1,79 @@
import argparse
import os
import torch
import torch.distributed as dist
from auto_gptq import AutoGPTQForCausalLM
import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def run_llama_test(args):
quantized_model_dir = args.quantized_path
max_batch_size = args.max_batch_size
max_input_len = args.max_input_len
max_output_len = args.max_output_len
micro_batch_size = args.micro_batch_size
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(
quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device()
)
engine = CaiInferEngine(
tp_size=2,
pp_size=2,
model=model,
model_policy=LlamaModelInferPolicy(),
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
quant="gptq",
)
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
def check_llama(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_test(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptq_llama(args):
spawn(check_llama, args.tp_size * args.pp_size, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
args = parser.parse_args()
test_gptq_llama(args)
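As a usage note, this script is launched directly, e.g. python <this_script>.py -q /path/to/llama-7b-gptq --tp_size 2 --pp_size 2 (the script name and checkpoint path are placeholders); spawn then starts tp_size * pp_size local ranks, matching the engine's world-size assertion.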


@ -0,0 +1,76 @@
import argparse
import torch
import torch.distributed as dist
import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@torch.no_grad()
def run_llama_test(args):
quantized_model_dir = args.quantized_path
max_batch_size = args.max_batch_size
max_input_len = args.max_input_len
max_output_len = args.max_output_len
micro_batch_size = args.micro_batch_size
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b")
model = model.cuda()
engine = CaiInferEngine(
tp_size=2,
pp_size=2,
model=model,
model_policy=LlamaModelInferPolicy(),
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
quant="smoothquant",
)
output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == 32, f"{len(output)}, {32}"
def check_llama(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_test(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_smoothquant_llama(args):
spawn(check_llama, args.tp_size * args.pp_size, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
args = parser.parse_args()
test_smoothquant_llama(args)


@ -9,6 +9,10 @@ from colossalai.inference import BloomModelInferPolicy, CaiInferEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
try:
import lightllm  # noqa: F401 -- probe whether the lightllm kernels are importable (package name assumed, not shown in this diff)
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
@ -88,7 +92,10 @@ def check_tp_inference(rank, world_size, port):
run_tp_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()


@ -9,6 +9,10 @@ from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
try:
import lightllm  # noqa: F401 -- probe whether the lightllm kernels are importable (package name assumed, not shown in this diff)
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
@ -90,7 +94,10 @@ def check_tp_inference(rank, world_size, port):
run_tp_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()