mirror of https://github.com/hpcaitech/ColossalAI
[refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda

pull/5023/head^2
parent 48d0a58d10
commit 450115bd0f
@@ -14,10 +14,7 @@ from ..tensor_parallel.kvcache_manager import MemoryManager
PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
"LlamaForCausalLM",
"BloomForCausalLM",
]
_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM"]


class CaiInferEngine:
@@ -70,12 +67,21 @@ class CaiInferEngine:
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
quant: str = None,
verbose: bool = False,
# TODO: implement early_stopping, and various gerneration options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
if quant == "gptq":
from ..quant.gptq import GPTQManager

self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
model = model.model
elif quant == "smoothquant":
model = model.model

assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
@@ -85,9 +91,14 @@ class CaiInferEngine:
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"

assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
self.pp_size = pp_size
self.tp_size = tp_size
self.quant = quant

if quant == "smoothquant" and dtype != "fp32":
dtype = "fp32"
print("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32")

if dtype == "fp16":
self.dtype = torch.float16
@@ -118,6 +129,8 @@ class CaiInferEngine:
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)

self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)

def inference(self, input_list):
"""
@@ -149,6 +162,7 @@ class CaiInferEngine:
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
quant=self.quant,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@@ -158,7 +172,7 @@ class CaiInferEngine:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
if model.config.model_type == "llama":
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads // self.tp_size
head_num = model.config.num_key_value_heads // self.tp_size
num_hidden_layers = (
model.config.num_hidden_layers
if hasattr(model.config, "num_hidden_layers")
@@ -171,5 +185,8 @@ class CaiInferEngine:
num_hidden_layers = model.config.n_layer
layer_num = num_hidden_layers // self.pp_size

cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
if self.quant == "smoothquant":
cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
else:
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
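For context, a minimal usage sketch of the new `quant` argument on the engine, assuming a 2x2 TP/PP launch and a placeholder GPTQ checkpoint path; it mirrors the test scripts added later in this commit, with `AutoGPTQForCausalLM` coming from the external `auto_gptq` package.

```python
import torch
from auto_gptq import AutoGPTQForCausalLM

from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

# Load an already-quantized LLaMA checkpoint (the path is a placeholder).
model = AutoGPTQForCausalLM.from_quantized(
    "/path/to/llama-gptq", inject_fused_attention=False, device=torch.cuda.current_device()
)

# With quant="gptq" the engine unwraps the quantized wrapper (model.model)
# and sets up a GPTQManager from model.quantize_config internally.
engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
    quant="gptq",
)
```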
@@ -1,3 +1,4 @@
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards

__all__ = ["LlamaInferenceForwards"]
__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"]
@@ -45,14 +45,15 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()

if self.shard_config.inference_gptq:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.quant == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
@@ -94,6 +95,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
],
)

elif self.shard_config.quant == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)

policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()

infer_forward = LlamaInferenceForwards.llama_model_forward
@@ -0,0 +1 @@
from .smoothquant.models.llama import SmoothLlamaForCausalLM
@@ -2,3 +2,4 @@ from .cai_gptq import HAS_AUTO_GPTQ
if HAS_AUTO_GPTQ:
from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
from .gptq_manager import GPTQManager
@@ -0,0 +1,61 @@
import torch


class GPTQManager:
def __init__(self, quant_config, max_input_len: int = 1):
self.max_dq_buffer_size = 1
self.max_inner_outer_dim = 1
self.bits = quant_config.bits
self.use_act_order = quant_config.desc_act
self.max_input_len = 1
self.gptq_temp_state_buffer = None
self.gptq_temp_dq_buffer = None
self.quant_config = quant_config

def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
from .cai_gptq import CaiQuantLinear

HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder

gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False

for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

if self.use_act_order:
self.max_inner_outer_dim = max(
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return

max_input_len = 1
if self.use_act_order:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros(
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
self.gptq_temp_dq_buffer = torch.zeros(
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
)

gptq_cuda.prepare_buffers(
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

torch.cuda.empty_cache()
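A short sketch of how the engine drives this manager, based on the engine changes earlier in this commit; the quantized model is assumed to expose `quantize_config` the way `auto_gptq` models do, and the helper name below is illustrative only.

```python
from colossalai.inference.quant.gptq import GPTQManager


def build_gptq_manager(quantized_model, sharded_model, max_input_len: int = 32):
    # The manager is created from the checkpoint's quantize_config (bits,
    # desc_act) before sharding, as CaiInferEngine does for quant="gptq".
    manager = GPTQManager(quantized_model.quantize_config, max_input_len=max_input_len)
    # After ShardFormer has swapped in CaiQuantLinear modules, size the
    # exllama-style temp_state / temp_dq buffers and tune the CUDA kernels.
    manager.post_init_gptq_buffer(sharded_model)
    return manager
```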
@@ -4,9 +4,7 @@ try:
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
raise ImportError(
"Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
)
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")

if HAS_TORCH_INT:
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
@@ -9,7 +9,6 @@ from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import accelerate
import numpy as np
import torch
import torch.nn as nn
@@ -24,6 +23,15 @@ from transformers.utils.hub import PushToHubMixin, cached_file
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

try:
import accelerate

HAS_ACCELERATE = True
except ImportError:
HAS_ACCELERATE = False
print("accelerate is not installed.")


SUPPORTED_MODELS = ["llama"]
@@ -1,17 +1,25 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py

import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
except:
HAS_SMOOTHQUANT_CUDA = False
raise ImportError("CUDA smoothquant linear is not installed")
print("CUDA smoothquant linear is not installed")


class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ class W8A8BFP32OFP32Linear(torch.nn.Module):
)
self.register_buffer(
"bias",
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))

def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
self.bias = self.bias.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(torch.float32)
return self

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
return self

@torch.no_grad()
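The `_apply`/`to` overrides above exist so that a module-wide `.half()` cannot downcast the fp32 bias buffer that the int8 kernels expect. A self-contained illustration of the same pattern (a stand-in module, not the actual torch-int backed class):

```python
import torch


class KeepBiasFP32(torch.nn.Module):
    """Minimal stand-in reproducing the bias-dtype guard used above."""

    def __init__(self, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.zeros(out_features, out_features, dtype=torch.int8))
        self.register_buffer("bias", torch.zeros((1, out_features), dtype=torch.float32))

    def _apply(self, fn):
        # Let PyTorch convert every buffer, then force the bias back to fp32.
        super()._apply(fn)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)
        return self


m = KeepBiasFP32(8).half()             # .half() routes through _apply
assert m.bias.dtype == torch.float32   # bias stays fp32; int8 weight is untouched by .half()
```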
@@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -18,7 +17,6 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaMLP,
LlamaRotaryEmbedding,
repeat_kv,
rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward
@@ -31,10 +29,31 @@ from colossalai.kernel.triton import (
smooth_token_attention_fwd,
)

try:
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T

HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


from .base_model import BaseSmoothForCausalLM
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class LLamaSmoothquantAttention(nn.Module):
def __init__(
self,
@@ -116,7 +135,6 @@ class LLamaSmoothquantAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -131,8 +149,7 @@ class LLamaSmoothquantAttention(nn.Module):
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

cos = rotary_emb[0]
sin = rotary_emb[1]
cos, sin = infer_state.position_cos, infer_state.position_sin

int8_rotary_embedding_fwd(
query_states.view(-1, self.num_heads, self.head_dim),
@@ -348,7 +365,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -378,7 +394,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
rotary_emb=rotary_emb,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@@ -650,15 +665,15 @@ def llama_model_forward(
raise NotImplementedError("not implement gradient_checkpointing and training options ")

if past_key_values_length == 0:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -673,7 +688,6 @@ def llama_model_forward(
layer_outputs = decoder_layer(
hidden_states,
rotary_emb=(position_cos, position_sin),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
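The `repeat_kv` helper re-declared above only expands and reshapes; a quick shape check in plain PyTorch (no quantized kernels required) shows what it does for grouped-query attention.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper above: repeat each KV head n_rep times.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 4, 16, 64)   # (batch, kv_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=8)     # 4 KV heads now serve 32 attention heads
assert out.shape == (2, 32, 16, 64)
```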
@@ -0,0 +1,264 @@
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
if smooth_linear.bias is not None:
bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)

smooth_split_out_features = para_linear.out_features // split_num

for i in range(split_num):
para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
]

if para_linear.bias is not None:
para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
:, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
]


def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)

smooth_split_in_features = para_linear.in_features // split_num

for i in range(split_num):
para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
:, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
]

if smooth_linear.bias is not None:
para_linear.bias.copy_(smooth_linear.bias)


class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0

@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
out_features = module.out_features

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)

if out_features < tp_size:
return module

if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()
linear_1d.b = module.b.clone().detach()
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d


class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0

@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)

if in_features < tp_size:
return module

if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = torch.tensor(module.a)
linear_1d.b = torch.tensor(module.b)

split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
if linear_1d.bias is not None:
linear_1d.bias = linear_1d.bias // tp_size

return linear_1d

@torch.no_grad()
def forward(self, x):
output = super().forward(x)
if self.tp_size > 1:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
return output


class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0

@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
out_features = module.out_features

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)

if out_features < tp_size:
return module

if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()

split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d


class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0

@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
out_features = module.out_features

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)

if out_features < tp_size:
return module

if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()

split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d


class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__(in_features, out_features, alpha, beta)
self.process_group = None
self.tp_size = 1
self.tp_rank = 0

@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)

if in_features < tp_size:
return module

if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
linear_1d.tp_size = tp_size
linear_1d.tp_rank = tp_rank
linear_1d.process_group = process_group
linear_1d.a = module.a.clone().detach()

split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
if linear_1d.bias is not None:
linear_1d.bias = linear_1d.bias / tp_size

return linear_1d

@torch.no_grad()
def forward(self, x):
output = super().forward(x)
if self.tp_size > 1:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
return output
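The slicing that `split_row_copy` / `split_column_copy` perform per tensor-parallel rank can be illustrated with plain float tensors; the sketch below mimics the `split_num=1` case and is not the quantized implementation.

```python
import torch


def row_shard(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # The Row* layers above keep a contiguous block of output rows per rank,
    # as split_row_copy does with split_num=1.
    out_features = weight.shape[0]
    shard = out_features // tp_size
    return weight[tp_rank * shard : (tp_rank + 1) * shard, :]


def col_shard(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # The Col* layers keep a block of input columns per rank; their partial
    # outputs are then summed with the all-reduce seen in forward() above.
    in_features = weight.shape[1]
    shard = in_features // tp_size
    return weight[:, tp_rank * shard : (tp_rank + 1) * shard]


w = torch.arange(32.0).reshape(4, 8)  # (out_features=4, in_features=8)
assert row_shard(w, tp_size=2, tp_rank=1).shape == (2, 8)
assert col_shard(w, tp_size=2, tp_rank=0).shape == (4, 4)
```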
@@ -37,6 +37,7 @@ class ShardConfig:
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
quant: str = None
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -77,4 +78,3 @@ class ShardConfig:
Set default params for inference.
"""
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass
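The new `quant` field is what the inference engine forwards into ShardFormer (see the `_shardformer` change earlier in this diff). A hedged sketch of that wiring; `model` and `model_policy` are placeholders, and only the options visible in this commit are spelled out, with the TP/PP process-group arguments the engine also passes elided.

```python
from colossalai.shardformer import ShardConfig, ShardFormer


def quantized_shard(model, model_policy, quant: str = "smoothquant"):
    # quant may be "smoothquant", "gptq", or None; the policy above branches
    # on shard_config.quant to pick the matching replacement modules.
    shard_config = ShardConfig(
        enable_flash_attention=False,
        enable_jit_fused=False,
        enable_sequence_parallelism=False,
        quant=quant,
    )
    shard_model, _ = ShardFormer(shard_config=shard_config).optimize(model, model_policy)
    return shard_model
```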
@@ -0,0 +1,79 @@
import argparse
import os

import torch
import torch.distributed as dist
from auto_gptq import AutoGPTQForCausalLM

import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def run_llama_test(args):
quantized_model_dir = args.quantized_path
max_batch_size = args.max_batch_size
max_input_len = args.max_input_len
max_output_len = args.max_output_len
micro_batch_size = args.micro_batch_size
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(
quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device()
)

engine = CaiInferEngine(
tp_size=2,
pp_size=2,
model=model,
model_policy=LlamaModelInferPolicy(),
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
quant="gptq",
)

def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)

inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)

output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"


def check_llama(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_test(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptq_llama(args):
spawn(check_llama, args.tp_size * args.pp_size, args=args)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
args = parser.parse_args()

test_gptq_llama(args)
@@ -0,0 +1,76 @@
import argparse

import torch
import torch.distributed as dist

import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


@torch.no_grad()
def run_llama_test(args):
quantized_model_dir = args.quantized_path
max_batch_size = args.max_batch_size
max_input_len = args.max_input_len
max_output_len = args.max_output_len
micro_batch_size = args.micro_batch_size

def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)

inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)

model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b")
model = model.cuda()

engine = CaiInferEngine(
tp_size=2,
pp_size=2,
model=model,
model_policy=LlamaModelInferPolicy(),
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
quant="smoothquant",
)

output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == 32, f"{len(output)}, {32}"


def check_llama(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_test(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_smoothquant_llama():
spawn(check_llama, args.tp_size * args.pp_size, args=args)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")

args = parser.parse_args()
test_smoothquant_llama()
@@ -9,6 +9,10 @@ from colossalai.inference import BloomModelInferPolicy, CaiInferEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
try:
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False


def data_gen():
@@ -88,7 +92,10 @@ def check_tp_inference(rank, world_size, port):
run_tp_inference_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -9,6 +9,10 @@ from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
try:
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False


def data_gen():
@@ -90,7 +94,10 @@ def check_tp_inference(rank, world_size, port):
run_tp_inference_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()