diff --git a/colossalai/inference/hybridengine/engine.py b/colossalai/inference/hybridengine/engine.py
index 43824858a..3a80723c3 100644
--- a/colossalai/inference/hybridengine/engine.py
+++ b/colossalai/inference/hybridengine/engine.py
@@ -14,10 +14,7 @@ from ..tensor_parallel.kvcache_manager import MemoryManager
 
 PP_AXIS, TP_AXIS = 0, 1
 
-_supported_models = [
-    "LlamaForCausalLM",
-    "BloomForCausalLM",
-]
+_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM"]
 
 
 class CaiInferEngine:
@@ -70,12 +67,21 @@ class CaiInferEngine:
         max_batch_size: int = 4,
         max_input_len: int = 32,
         max_output_len: int = 32,
+        quant: str = None,
         verbose: bool = False,
         # TODO: implement early_stopping, and various gerneration options
         early_stopping: bool = False,
         do_sample: bool = False,
         num_beams: int = 1,
     ) -> None:
+        if quant == "gptq":
+            from ..quant.gptq import GPTQManager
+
+            self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
+            model = model.model
+        elif quant == "smoothquant":
+            model = model.model
+
         assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
         assert (
             tp_size * pp_size == dist.get_world_size()
@@ -85,9 +91,14 @@ class CaiInferEngine:
         assert max_batch_size <= 64, "Max batch size exceeds the constraint"
         assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
-
+        assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq', or None"
         self.pp_size = pp_size
         self.tp_size = tp_size
+        self.quant = quant
+
+        if quant == "smoothquant" and dtype != "fp32":
+            dtype = "fp32"
+            print("Warning: smoothquant only supports fp32 and int8 mixed precision; setting dtype to fp32.")
 
         if dtype == "fp16":
             self.dtype = torch.float16
@@ -118,6 +129,8 @@ class CaiInferEngine:
         self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
 
         self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
+        if quant == "gptq":
+            self.gptq_manager.post_init_gptq_buffer(self.model)
 
     def inference(self, input_list):
         """
@@ -149,6 +162,7 @@
             enable_flash_attention=False,
             enable_jit_fused=False,
             enable_sequence_parallelism=False,
+            quant=self.quant,
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
@@ -158,7 +172,7 @@
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
         if model.config.model_type == "llama":
             head_dim = model.config.hidden_size // model.config.num_attention_heads
-            head_num = model.config.num_attention_heads // self.tp_size
+            head_num = model.config.num_key_value_heads // self.tp_size
             num_hidden_layers = (
                 model.config.num_hidden_layers
                 if hasattr(model.config, "num_hidden_layers")
@@ -171,5 +185,8 @@
                 num_hidden_layers = model.config.n_layer
         layer_num = num_hidden_layers // self.pp_size
 
-        cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
+        if self.quant == "smoothquant":
+            cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
+        else:
+            cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
         return cache_manager
diff --git a/colossalai/inference/hybridengine/modeling/__init__.py b/colossalai/inference/hybridengine/modeling/__init__.py
index 239bdebd7..a6603066a 100644
--- a/colossalai/inference/hybridengine/modeling/__init__.py
+++ b/colossalai/inference/hybridengine/modeling/__init__.py
@@ -1,3 +1,4 @@
+from .bloom import BloomInferenceForwards
 from .llama import LlamaInferenceForwards
 
-__all__ = ["LlamaInferenceForwards"]
+__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"]
diff --git a/colossalai/inference/hybridengine/polices/llama.py b/colossalai/inference/hybridengine/polices/llama.py
index 352562f1d..3cdfc0173 100644
--- a/colossalai/inference/hybridengine/polices/llama.py
+++ b/colossalai/inference/hybridengine/polices/llama.py
@@ -45,14 +45,15 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def module_policy(self):
         policy = super().module_policy()
-
-        if self.shard_config.inference_gptq:
+        decoder_attribute_replacement = {
+            "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            // self.shard_config.tensor_parallel_size,
+        }
+        if self.shard_config.quant == "gptq":
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
 
-            decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            }
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
@@ -94,6 +95,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 ],
             )
 
+        elif self.shard_config.quant == "smoothquant":
+            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
+            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
+                ColW8A8BFP32OFP32Linear,
+                RowW8A8B8O8Linear,
+                RowW8A8BFP32O32LinearSiLU,
+                RowW8A8BFP32OFP32Linear,
+            )
+
+            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=ColW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=RowW8A8BFP32O32LinearSiLU,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=RowW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=ColW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
         self.shard_config._infer()
 
         infer_forward = LlamaInferenceForwards.llama_model_forward
diff --git a/colossalai/inference/quant/__init__.py b/colossalai/inference/quant/__init__.py
new file mode 100644
index 000000000..18e0de9cc
--- /dev/null
+++ b/colossalai/inference/quant/__init__.py
@@ -0,0 +1 @@
+from .smoothquant.models.llama import SmoothLlamaForCausalLM
diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py
index c035f3979..4cf1fd658 100644
--- a/colossalai/inference/quant/gptq/__init__.py
+++ b/colossalai/inference/quant/gptq/__init__.py
@@ -2,3 +2,4 @@ from .cai_gptq import HAS_AUTO_GPTQ
 
 if HAS_AUTO_GPTQ:
     from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
+    from .gptq_manager import GPTQManager
diff --git a/colossalai/inference/quant/gptq/gptq_manager.py b/colossalai/inference/quant/gptq/gptq_manager.py
new file mode 100644
index 000000000..2d352fbef
--- /dev/null
+++ b/colossalai/inference/quant/gptq/gptq_manager.py
@@ -0,0 +1,62 @@
+import warnings
+
+import torch
+
+
+class GPTQManager:
+    def __init__(self, quant_config, max_input_len: int = 1):
+        self.max_dq_buffer_size = 1
+        self.max_inner_outer_dim = 1
+        self.bits = quant_config.bits
+        self.use_act_order = quant_config.desc_act
+        self.max_input_len = max_input_len
+        self.gptq_temp_state_buffer = None
+        self.gptq_temp_dq_buffer = None
+        self.quant_config = quant_config
+
+    def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
+        from .cai_gptq import CaiQuantLinear
+
+        HAS_GPTQ_CUDA = False
+        try:
+            from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
+            gptq_cuda = GPTQBuilder().load()
+            HAS_GPTQ_CUDA = True
+        except ImportError:
+            warnings.warn("CUDA gptq is not installed")
+            HAS_GPTQ_CUDA = False
+
+        for name, submodule in model.named_modules():
+            if isinstance(submodule, CaiQuantLinear):
+                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
+
+                if self.use_act_order:
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
+                self.bits = submodule.bits
+        if not (HAS_GPTQ_CUDA and self.bits == 4):
+            return
+
+        max_input_len = 1
+        if self.use_act_order:
+            max_input_len = self.max_input_len
+        # The temp_state buffer is required to reorder X in the act-order case.
+        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )
+        # Using the default from exllama repo here.
+        matmul_recons_thd = 8
+        matmul_fused_remap = False
+        matmul_no_half2 = False
+        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+        torch.cuda.empty_cache()
diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py
index 77541d861..1663028da 100644
--- a/colossalai/inference/quant/smoothquant/models/__init__.py
+++ b/colossalai/inference/quant/smoothquant/models/__init__.py
@@ -4,9 +4,7 @@ try:
     HAS_TORCH_INT = True
 except ImportError:
     HAS_TORCH_INT = False
-    raise ImportError(
-        "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
-    )
+    print("torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
 
 if HAS_TORCH_INT:
     from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py
index 9554be9ea..9fe3241cf 100644
--- a/colossalai/inference/quant/smoothquant/models/base_model.py
+++ b/colossalai/inference/quant/smoothquant/models/base_model.py
@@ -9,7 +9,6 @@ from functools import partial
 from os.path import isdir, isfile, join
 from typing import Dict, List, Optional, Union
 
-import accelerate
 import numpy as np
 import torch
 import torch.nn as nn
@@ -24,6 +23,15 @@ from transformers.utils.hub import PushToHubMixin, cached_file
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
 
+try:
+    import accelerate
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    print("accelerate is not installed.")
+
+
 SUPPORTED_MODELS = ["llama"]
diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py
index 969c390a0..03d994b32 100644
--- a/colossalai/inference/quant/smoothquant/models/linear.py
+++ b/colossalai/inference/quant/smoothquant/models/linear.py
@@ -1,17 +1,25 @@
 # modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 
 import torch
-from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
-from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+try:
+    from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
+    from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
 
 try:
     from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
 
     smoothquant_cuda = SmoothquantBuilder().load()
     HAS_SMOOTHQUANT_CUDA = True
-except ImportError:
+except:
     HAS_SMOOTHQUANT_CUDA = False
-    raise ImportError("CUDA smoothquant linear is not installed")
+    print("CUDA smoothquant linear is not installed")
 
 
 class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ class W8A8BFP32OFP32Linear(torch.nn.Module):
         )
         self.register_buffer(
             "bias",
-            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
+            torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
         )
         self.register_buffer("a", torch.tensor(alpha))
 
     def _apply(self, fn):
         # prevent the bias from being converted to half
         super()._apply(fn)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(torch.float32)
         return self
 
     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
         self.weight = self.weight.to(*args, **kwargs)
-        self.bias = self.bias.to(*args, **kwargs)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(*args, **kwargs)
+            self.bias = self.bias.to(torch.float32)
         return self
 
     @torch.no_grad()
diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py
index 30063857a..9d4bd9f77 100644
--- a/colossalai/inference/quant/smoothquant/models/llama.py
+++ b/colossalai/inference/quant/smoothquant/models/llama.py
@@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -18,7 +17,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaMLP,
     LlamaRotaryEmbedding,
-    repeat_kv,
     rotate_half,
 )
 from transformers.utils import add_start_docstrings_to_model_forward
@@ -31,10 +29,31 @@ from colossalai.kernel.triton import (
     smooth_token_attention_fwd,
 )
 
+try:
+    from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
+
 from .base_model import BaseSmoothForCausalLM
 from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class LLamaSmoothquantAttention(nn.Module):
     def __init__(
         self,
@@ -116,7 +135,6 @@ class LLamaSmoothquantAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -131,8 +149,7 @@ class LLamaSmoothquantAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        cos = rotary_emb[0]
-        sin = rotary_emb[1]
+        cos, sin = infer_state.position_cos, infer_state.position_sin
 
         int8_rotary_embedding_fwd(
             query_states.view(-1, self.num_heads, self.head_dim),
@@ -348,7 +365,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -378,7 +394,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
-            rotary_emb=rotary_emb,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -650,15 +665,15 @@ def llama_model_forward(
         raise NotImplementedError("not implement gradient_checkpointing and training options ")
 
     if past_key_values_length == 0:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
     else:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
@@ -673,7 +688,6 @@
 
         layer_outputs = decoder_layer(
             hidden_states,
-            rotary_emb=(position_cos, position_sin),
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
diff --git a/colossalai/inference/quant/smoothquant/models/parallel_linear.py b/colossalai/inference/quant/smoothquant/models/parallel_linear.py
new file mode 100644
index 000000000..962b687a1
--- /dev/null
+++ b/colossalai/inference/quant/smoothquant/models/parallel_linear.py
@@ -0,0 +1,264 @@
+from typing import List, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import ParallelModule
+
+from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
+
+
+def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
+    qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
+    if smooth_linear.bias is not None:
+        bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)
+
+    smooth_split_out_features = para_linear.out_features // split_num
+
+    for i in range(split_num):
+        para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
+            tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
+        ]
+
+        if para_linear.bias is not None:
+            para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
+                :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
+            ]
+
+
+def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
+    qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)
+
+    smooth_split_in_features = para_linear.in_features // split_num
+
+    for i in range(split_num):
+        para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
+            :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
+        ]
+
+    if smooth_linear.bias is not None:
+        para_linear.bias.copy_(smooth_linear.bias)
+
+
+class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        out_features = module.out_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+        linear_1d.b = module.b.clone().detach()
+        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        return linear_1d
+
+
+class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if in_features < tp_size:
+            return module
+
+        if in_features % tp_size != 0:
+            raise ValueError(
+                f"The size of in_features:{in_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = torch.tensor(module.a)
+        linear_1d.b = torch.tensor(module.b)
+
+        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        if linear_1d.bias is not None:
+            linear_1d.bias = linear_1d.bias // tp_size
+
+        return linear_1d
+
+    @torch.no_grad()
+    def forward(self, x):
+        output = super().forward(x)
+        if self.tp_size > 1:
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+        return output
+
+
+class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        out_features = module.out_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+
+        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        return linear_1d
+
+
+class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        out_features = module.out_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+
+        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        return linear_1d
+
+
+class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__(in_features, out_features, alpha, beta)
+        self.process_group = None
+        self.tp_size = 1
+        self.tp_rank = 0
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        tp_rank = dist.get_rank(process_group)
+
+        if in_features < tp_size:
+            return module
+
+        if in_features % tp_size != 0:
+            raise ValueError(
+                f"The size of in_features:{in_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
+        linear_1d.tp_size = tp_size
+        linear_1d.tp_rank = tp_rank
+        linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+
+        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+        if linear_1d.bias is not None:
+            linear_1d.bias = linear_1d.bias / tp_size
+
+        return linear_1d
+
+    @torch.no_grad()
+    def forward(self, x):
+        output = super().forward(x)
+        if self.tp_size > 1:
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+        return output
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 2aa613983..c7d63c234 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -37,6 +37,7 @@ class ShardConfig:
     inference_gptq: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
+    quant: str = None
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -77,4 +78,3 @@
         Set default params for inference.
         """
         # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
-        pass
diff --git a/examples/inference/hybrid_gptq_llama.py b/examples/inference/hybrid_gptq_llama.py
new file mode 100644
index 000000000..122307e18
--- /dev/null
+++ b/examples/inference/hybrid_gptq_llama.py
@@ -0,0 +1,79 @@
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+from auto_gptq import AutoGPTQForCausalLM
+
+import colossalai
+from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def run_llama_test(args):
+    quantized_model_dir = args.quantized_path
+    max_batch_size = args.max_batch_size
+    max_input_len = args.max_input_len
+    max_output_len = args.max_output_len
+    micro_batch_size = args.micro_batch_size
+    # load quantized model to the first GPU
+    model = AutoGPTQForCausalLM.from_quantized(
+        quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device()
+    )
+
+    engine = CaiInferEngine(
+        tp_size=args.tp_size,
+        pp_size=args.pp_size,
+        model=model,
+        model_policy=LlamaModelInferPolicy(),
+        max_batch_size=max_batch_size,
+        max_input_len=max_input_len,
+        max_output_len=max_output_len,
+        micro_batch_size=micro_batch_size,
+        quant="gptq",
+    )
+
+    def data_gen():
+        input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+        attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    inputs = data_gen()
+    for k, v in inputs.items():
+        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+            new_shape = [1] * v.dim()
+            new_shape[0] = 16
+            inputs[k] = v.to("cuda").repeat(*new_shape)
+
+    output = engine.inference(inputs)
+    if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
+
+
+def check_llama(rank, world_size, port, args):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gptq_llama(args):
+    spawn(check_llama, args.tp_size * args.pp_size, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+    parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
+    parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
+    parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
+    parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
+    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
+    args = parser.parse_args()
+
+    test_gptq_llama(args)
diff --git a/examples/inference/hybrid_smoothquant_llama.py b/examples/inference/hybrid_smoothquant_llama.py
new file mode 100644
index 000000000..6f5b54b35
--- /dev/null
+++ b/examples/inference/hybrid_smoothquant_llama.py
@@ -0,0 +1,76 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
+from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+
+@torch.no_grad()
+def run_llama_test(args):
+    quantized_model_dir = args.quantized_path
+    max_batch_size = args.max_batch_size
+    max_input_len = args.max_input_len
+    max_output_len = args.max_output_len
+    micro_batch_size = args.micro_batch_size
+
+    def data_gen():
+        input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+        attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    inputs = data_gen()
+    for k, v in inputs.items():
+        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+            new_shape = [1] * v.dim()
+            new_shape[0] = 16
+            inputs[k] = v.to("cuda").repeat(*new_shape)
+
+    model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b")
+    model = model.cuda()
+
+    engine = CaiInferEngine(
+        tp_size=args.tp_size,
+        pp_size=args.pp_size,
+        model=model,
+        model_policy=LlamaModelInferPolicy(),
+        max_batch_size=max_batch_size,
+        max_input_len=max_input_len,
+        max_output_len=max_output_len,
+        micro_batch_size=micro_batch_size,
+        quant="smoothquant",
+    )
+
+    output = engine.inference(inputs)
+    if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
+
+
+def check_llama(rank, world_size, port, args):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_smoothquant_llama(args):
+    spawn(check_llama, args.tp_size * args.pp_size, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+    parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
+    parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
+    parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
size") + parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length") + parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length") + + args = parser.parse_args() + test_smoothquant_llama() diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py index 7f68e82c6..14b745982 100644 --- a/tests/test_infer/test_hybrid_bloom.py +++ b/tests/test_infer/test_hybrid_bloom.py @@ -9,6 +9,10 @@ from colossalai.inference import BloomModelInferPolicy, CaiInferEngine from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") +try: + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False def data_gen(): @@ -88,7 +92,10 @@ def check_tp_inference(rank, world_size, port): run_tp_inference_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py index ca2349b18..d917ae2d8 100644 --- a/tests/test_infer/test_hybrid_llama.py +++ b/tests/test_infer/test_hybrid_llama.py @@ -9,6 +9,10 @@ from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") +try: + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False def data_gen(): @@ -90,7 +94,10 @@ def check_tp_inference(rank, world_size, port): run_tp_inference_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()