From e37ee2fb65fc77c275b816968d91776322fd7695 Mon Sep 17 00:00:00 2001 From: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:56:46 +0800 Subject: [PATCH] [Feat]Tensor Model Parallel Support For Inference (#5563) * tensor parallel support naive source * [fix]precision, model load and refactor the framework * add tp unit test * docstring * fix do_sample --- colossalai/inference/core/engine.py | 161 +++++++--- colossalai/inference/core/plugin.py | 140 +++++++++ colossalai/inference/core/request_handler.py | 6 +- .../modeling/models/nopadding_llama.py | 295 +++++++++++++----- .../modeling/policy/nopadding_llama.py | 59 +++- colossalai/inference/utils.py | 53 ++++ tests/test_infer/test_cuda_graph.py | 2 +- tests/test_infer/test_inference_engine.py | 74 +++-- 8 files changed, 640 insertions(+), 150 deletions(-) create mode 100644 colossalai/inference/core/plugin.py diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 466f6749b..c30db3e0c 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -5,8 +5,17 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn -from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from torch import distributed as dist +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData @@ -14,6 +23,8 @@ from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.spec import Drafter, GlideInput from colossalai.inference.struct import Sequence +from colossalai.inference.utils import get_model_size, has_index_file +from colossalai.interface import ModelWrapper from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -25,10 +36,10 @@ __all__ = ["InferenceEngine"] PP_AXIS, TP_AXIS = 0, 1 -_supported_models = [ - "LlamaForCausalLM", - "BaichuanForCausalLM", -] +_supported_models = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] @@ -39,7 +50,7 @@ class InferenceEngine: InferenceEngine which manages the inference process.. Args: - model (nn.Module): Path or nn.Module of this model. + model_or_path (nn.Module or str): Path or nn.Module of this model. tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. verbose (bool): Determine whether or not to log the generation process. 
@@ -48,53 +59,25 @@ class InferenceEngine: def __init__( self, - model: nn.Module, + model_or_path: Union[nn.Module, str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], inference_config: InferenceConfig, verbose: bool = False, model_policy: Policy = None, ) -> None: self.inference_config = inference_config - self.model_config = model.config - self.model = model - self.device = torch.device("cuda") self.dtype = inference_config.dtype - self.tokenizer = tokenizer - self.tokenizer.pad_token = self.tokenizer.eos_token self.high_precision = inference_config.high_precision - self._verify_args() - - self.generation_config = inference_config.to_generation_config(self.model_config) - model.eval() - model = model.to(self.dtype) - model = model.to(self.device) - - # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = False - self.drafter_model = None - self.drafter = None - self.use_glide = False - self.n_spec_tokens = self.inference_config.max_n_spec_tokens - - if model_policy is None: - if self.inference_config.pad_input: - model_type = "padding_" + self.model_config.model_type - else: - model_type = "nopadding_" + self.model_config.model_type - model_policy = model_policy_map[model_type]() - - pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size) - - self.model = self._shardformer( - model, - model_policy, - None, - pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None, - ) self.verbose = verbose - if verbose: - self.logger = get_dist_logger(__name__) + self.logger = get_dist_logger(__name__) + + self.init_model(model_or_path, model_policy) + + self.generation_config = inference_config.to_generation_config(self.model_config) + + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token self.request_handler = RequestHandler(self.inference_config, self.model_config) self.k_cache, self.v_cache = self.request_handler.get_kvcache() @@ -111,6 +94,91 @@ class InferenceEngine: self.capture_model(self.k_cache, self.v_cache) + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = False + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self._verify_args() + + def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. 
+ model_policy (Policy): the policy to replace the model + """ + + if isinstance(model_or_path, str): + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + arch = getattr(hf_config, "architectures")[0] + model = _supported_models[arch](hf_config) + except Exception as e: + self.logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + model = model.to(self.dtype).eval() + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + if self.inference_config.pad_input: + model_type = "padding_" + self.model_config.model_type + else: + model_type = "nopadding_" + self.model_config.model_type + model_policy = model_policy_map[model_type]() + + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + if isinstance(model_or_path, str): + from colossalai.inference.core.plugin import InferCheckpoint_io + + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(model_or_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + @torch.inference_mode() def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): assert self.use_cuda_graph, "please turn on the cuda graph" @@ -194,8 +262,11 @@ class InferenceEngine: raise TypeError( f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" ) - if self.model.__class__.__name__ not in _supported_models: - raise ValueError(f"Model {self.model.__class__.__name__} is not supported.") + if isinstance(self.model, ModelWrapper): + model = self.model.module + assert ( + model.__class__.__name__ in _supported_models.keys() + ), f"Model {self.model.__class__.__name__} is not supported." 
def _shardformer( self, diff --git a/colossalai/inference/core/plugin.py b/colossalai/inference/core/plugin.py new file mode 100644 index 000000000..d6a2b8b16 --- /dev/null +++ b/colossalai/inference/core/plugin.py @@ -0,0 +1,140 @@ +import logging +import os +from functools import reduce +from pathlib import Path +from typing import Optional + +import torch + +from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO +from colossalai.checkpoint_io.index_file import CheckpointIndexFile +from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class InferCheckpoint_io(GeneralCheckpointIO): + """ + This class is for inference model loading, most codes are copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO. + Origin HybridParallelCheckpointIO contains some codes about MixPrecision-Training, so we remove them and build a relatively clean class specifically for Inference. + """ + + def __init__( + self, + verbose: bool = True, + ) -> None: + super().__init__() + self.verbose = verbose + self.coordinator = DistCoordinator() + + def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model = model.unwrap() + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + missing_keys = [] + missing_file_keys = [] + + def _load(name: str): + if name not in weight_map: + missing_file_keys.append(name) + return + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + + load_state_dict_into_model( + model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True + ) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. 
+ non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persistent_buffers: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + _load(extra_state_key) + + if self.verbose and self.coordinator.is_master(): + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + if len(missing_keys) == 0: + raise RuntimeError( + "No weigth is loaded into the model. Please check the checkpoint files and the model structure." + ) + + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) + remain_keys = remain_keys.union(set(missing_file_keys)) + if len(remain_keys) > 0: + if strict: + error_msgs = "Missing key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in missing_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + else: + if self.coordinator.is_master(): + logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}") + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + return NotImplementedError diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 327a7e9ce..61ae3a4df 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -140,7 +140,7 @@ class RequestHandler: fd_inter_tensor.initialize( max_batch_size=max_n_tokens, - num_attn_heads=model_config.num_attention_heads, + num_attn_heads=model_config.num_attention_heads // inference_config.tp_size, kv_max_split_num=kv_max_split_num, head_dim=head_dim, dtype=self.dtype, @@ -150,7 +150,7 @@ class RequestHandler: # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, # which may cause bugs and this issue should be fixed later. 
self.running_bb = BatchBucket( - num_heads=model_config.num_attention_heads, + num_heads=model_config.num_attention_heads // inference_config.tp_size, head_dim=head_dim, max_batch_size=self.max_batch_size, max_length=inference_config.max_input_len + inference_config.max_output_len, @@ -161,7 +161,7 @@ class RequestHandler: device=device, ) self.prefill_bb = BatchBucket( - num_heads=model_config.num_attention_heads, + num_heads=model_config.num_attention_heads // inference_config.tp_size, head_dim=head_dim, max_batch_size=self.max_batch_size, max_length=inference_config.max_input_len + inference_config.max_output_len, diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 5ef576e51..be05e0838 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -1,8 +1,11 @@ # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py -from typing import List, Optional, Tuple +import itertools +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F +from torch import nn +from torch.distributed import ProcessGroup from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaConfig, @@ -26,6 +29,8 @@ from colossalai.kernel.triton import ( rotary_embedding, ) from colossalai.logging import get_dist_logger +from colossalai.shardformer.layer.parallel_module import ParallelModule +from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor inference_ops = InferenceOpsLoader().load() @@ -68,7 +73,8 @@ def llama_causal_lm_forward( use_cuda_kernel=inputmetadata.use_cuda_kernel, # Note currently the cuda kernel of layernorm, rotary_embedding_and_cache_copy couldn't pass the unitest but triton kernel could high_precision=inputmetadata.high_precision, ) - logits = torch.mm(hidden_states, self.lm_head.weight) + + logits = self.lm_head(hidden_states) return logits @@ -109,6 +115,7 @@ def llama_model_forward( logger.warning("CUDA kernel is disabled for speculative-decoding.") hidden_states = self.embed_tokens(input_tokens_ids) + cu_seqlens = None # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now @@ -126,7 +133,7 @@ def llama_model_forward( cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) elif use_cuda_kernel: - if inputmetadata != torch.float32 and use_flash_attn2: + if inputmetadata.dtype != torch.float32 and use_flash_attn2: cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) hidden_dim = self._cos_cached.size(-1) @@ -270,7 +277,129 @@ def llama_rmsnorm_forward( return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) -class NopadLlamaAttention(LlamaAttention): +class NopadLlamaMLP(ParallelModule, LlamaMLP): + def __init__( + self, + config: LlamaConfig, + mlp_gproj_w: torch.Tensor = None, + mlp_uproj_w: torch.Tensor = None, + mlp_dproj: ParallelModule = None, + process_group: ProcessGroup = None, + ): + """A Unified Layer for + + Args: + config (LlamaConfig): Holding the Llama model config. + mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None. + mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None. + mlp_dproj (Linear1D_Row, optional): The Linear1D_Row mlp_dproj weight. 
Defaults to None. + """ + ParallelModule.__init__(self) + self.config = config + assert is_distributed_tensor( + mlp_gproj_w + ), "mlp_gproj_w must be dtensor so we could get the layout of the weight" + self.helper_layout = ( + mlp_gproj_w.dist_layout + ) # NOTE this is a hack for the right load/shard of gate_up_weight(used in _load_from_state_dict) + self.gate_up_weight = nn.Parameter( + torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0) + ) + self.down_proj = mlp_dproj + self.process_group = process_group + + @staticmethod + def from_native_module( + module: LlamaMLP, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP. + + Args: + module (LlamaMLP): The origin LlamaMLP layer. + """ + + config = module.config + + mlp_gproj_w = module.gate_proj.weight + assert is_distributed_tensor( + module.gate_proj.weight + ), "gate_proj.weight must be dtensor so we could get the layout of the weight" + mlp_uproj_w = module.up_proj.weight + mlp_dproj = module.down_proj + + mlp_layer = NopadLlamaMLP( + config=config, + mlp_gproj_w=mlp_gproj_w, + mlp_uproj_w=mlp_uproj_w, + mlp_dproj=mlp_dproj, + process_group=process_group, + ) + + return mlp_layer + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight) + + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + key = "gate_up_weight" + k1 = "gate_proj.weight" + k2 = "up_proj.weight" + + gate_w = state_dict[prefix + k1] + up_w = state_dict[prefix + k2] + + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec) + up_w = distribute_tensor(up_w, device_mesh, sharding_spec) + + gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0) + + input_param = nn.Parameter( + gate_up_w + ) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) + param = local_state[key] + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + + strict = False # to avoid unexpected_keys + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. 
+ """ + hidden_states = hidden_states.expand(2, -1, -1) + gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) + act_out = inference_ops.silu_and_mul(gate_up_proj_out) + + return self.down_proj(act_out) + + def extra_repr(self) -> str: + return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False" + + +class NopadLlamaAttention(ParallelModule, LlamaAttention): def __init__( self, config: LlamaConfig, @@ -278,7 +407,11 @@ class NopadLlamaAttention(LlamaAttention): attn_qproj_w: torch.Tensor = None, attn_kproj_w: torch.Tensor = None, attn_vproj_w: torch.Tensor = None, - attn_oproj_w: torch.Tensor = None, + attn_oproj: ParallelModule = None, + process_group: ProcessGroup = None, + num_heads: int = None, + hidden_size: int = None, + num_key_value_heads: int = None, ): """This layer will replace the LlamaAttention. @@ -288,36 +421,54 @@ class NopadLlamaAttention(LlamaAttention): attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None. attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None. attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None. - attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None. + attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None. """ - super().__init__(config, layer_idx) - self.q_proj_weight = attn_qproj_w - self.k_proj_weight = attn_kproj_w - self.v_proj_weight = attn_vproj_w - self.o_proj_weight = attn_oproj_w + ParallelModule.__init__(self) + self.config = config + self.layer_idx = layer_idx + + self.o_proj = attn_oproj + self.process_group = process_group + + self.attention_dropout = config.attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True if self.num_heads == self.num_key_value_heads: - qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight] - self.qkv_weight = torch.stack(qkv_weight_list, dim=0) - - self.q_proj = None - self.k_proj = None - self.v_proj = None + qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] + self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) + self.helper_layout = ( + attn_qproj_w.dist_layout + ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) + else: + self.q_proj_weight = attn_qproj_w + self.k_proj_weight = attn_kproj_w + self.v_proj_weight = attn_vproj_w @staticmethod - def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention: + def from_native_module( + module: LlamaAttention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention. Args: module (LlamaAttention): The origin LlamaAttention layer. 
""" + config = module.config layer_idx = module.layer_idx - attn_qproj_w = module.q_proj.weight.transpose(0, 1) - attn_kproj_w = module.k_proj.weight.transpose(0, 1) - attn_vproj_w = module.v_proj.weight.transpose(0, 1) - attn_oproj_w = module.o_proj.weight.transpose(0, 1) + attn_qproj_w = module.q_proj.weight + attn_kproj_w = module.k_proj.weight + attn_vproj_w = module.v_proj.weight + assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor" + attn_oproj = module.o_proj attn_layer = NopadLlamaAttention( config=config, @@ -325,7 +476,11 @@ class NopadLlamaAttention(LlamaAttention): attn_qproj_w=attn_qproj_w, attn_kproj_w=attn_kproj_w, attn_vproj_w=attn_vproj_w, - attn_oproj_w=attn_oproj_w, + attn_oproj=attn_oproj, + process_group=process_group, + num_heads=module.num_heads, + hidden_size=module.hidden_size, + num_key_value_heads=module.num_key_value_heads, ) return attn_layer @@ -487,63 +642,57 @@ class NopadLlamaAttention(LlamaAttention): ) attn_output = attn_output.view(-1, self.hidden_size) - attn_output = torch.mm(attn_output, self.o_proj_weight) - + attn_output = self.o_proj(attn_output) return attn_output - -# NOTE This will cause difference as out length increases. -class NopadLlamaMLP(LlamaMLP): - def __init__( - self, - config: LlamaConfig, - mlp_gproj_w: torch.Tensor = None, - mlp_uproj_w: torch.Tensor = None, - mlp_dproj_w: torch.Tensor = None, + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - """This layer will replace the LlamaAttention. + # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight) + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - Args: - config (LlamaConfig): Holding the Llama model config. - mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None. - mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None. - mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None. - """ - super().__init__(config) - self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0) - self.down_proj_weight = mlp_dproj_w - self.gate_proj = None - self.up_proj = None - self.down_proj = None + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} - @staticmethod - def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP: - """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP. + key = "qkv_weight" + k1 = "q_proj.weight" + k2 = "k_proj.weight" + k3 = "v_proj.weight" + q_w = state_dict[prefix + k1] + k_w = state_dict[prefix + k2] + v_w = state_dict[prefix + k3] - Args: - module (LlamaMLP): The origin LlamaMLP layer. 
- """ - config = module.config + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + q_w = distribute_tensor(q_w, device_mesh, sharding_spec) + k_w = distribute_tensor(k_w, device_mesh, sharding_spec) + v_w = distribute_tensor(v_w, device_mesh, sharding_spec) - mlp_gproj_w = module.gate_proj.weight.transpose(0, 1) - mlp_uproj_w = module.up_proj.weight.transpose(0, 1) - mlp_dproj_w = module.down_proj.weight.transpose(0, 1) + qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0) - mlp_layer = NopadLlamaMLP( - config=config, - mlp_gproj_w=mlp_gproj_w, - mlp_uproj_w=mlp_uproj_w, - mlp_dproj_w=mlp_dproj_w, + input_param = nn.Parameter( + qkv_w + ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) + + param = local_state[key] + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + + strict = False # to avoid unexpected_keys + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) - return mlp_layer - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Args: - hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. - """ - hidden_states = hidden_states.expand(2, -1, -1) - gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) - act_out = inference_ops.silu_and_mul(gate_up_proj_out) - return torch.mm(act_out, self.down_proj_weight) + def extra_repr(self) -> str: + return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 292a6e5ff..3cadf601f 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -1,4 +1,3 @@ -from torch.nn import Parameter from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm from colossalai.inference.modeling.models.nopadding_llama import ( @@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import ( llama_rmsnorm_forward, ) from colossalai.inference.utils import init_to_get_rotary +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -21,26 +21,69 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def module_policy(self): policy = super().module_policy() - decoder_attribute_replacement = { - "lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False), - } - policy[LlamaForCausalLM] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - ) + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if 
getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + else: + decoder_attribute_replacement = None policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), SubModuleReplacementDescription( suffix="mlp", target_module=NopadLlamaMLP, ), + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), SubModuleReplacementDescription( suffix="self_attn", target_module=NopadLlamaAttention, ), - ] + ], ) + policy[LlamaForCausalLM] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True} + ) + ], + ) + + # self.shard_config._infer() self.append_or_create_method_replacement( description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM ) diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index a97b9c9d6..9e0d72586 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -2,8 +2,12 @@ Utils for model inference """ import os +import re +from pathlib import Path +from typing import Optional, Tuple import torch +from torch import nn def init_to_get_rotary(self, base=10000, use_elem=False): @@ -49,3 +53,52 @@ def init_to_get_rotary(self, base=10000, use_elem=False): self._cos_cached = torch.cos(freqs).to(self.dtype).cuda() self._sin_cached = torch.sin(freqs).to(self.dtype).cuda() + + +def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: + """ + Check whether the checkpoint has an index file. + + Args: + checkpoint_path (str): path to the checkpoint. + + Returns: + Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path) + """ + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + # check if it is .index.json + reg = re.compile("(.*?).index((\..*)?).json") + if reg.fullmatch(checkpoint_path.name) is not None: + return True, checkpoint_path + else: + return False, None + elif checkpoint_path.is_dir(): + index_files = list(checkpoint_path.glob("*.index.*json")) + + for index_file in index_files: + if "safetensors" in index_file.__str__(): + return True, index_file.__str__() # return the safetensors file first + + if len(index_files) == 1: + return True, index_files[0] + else: + assert ( + len(index_files) == 1 + ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" + return False, None + else: + raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.") + + +def get_model_size(model: nn.Module): + """Calculates the total size of the model weights (including biases) in bytes. + Args: + model: The PyTorch model to analyze. 
+ Returns: + The total size of the model weights in bytes. + """ + total_size = 0 + for key, param in model.named_parameters(): + total_size += param.element_size() * param.numel() + return total_size / (1024**3) diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py index cc5f1c7a2..a0a55d3ad 100644 --- a/tests/test_infer/test_cuda_graph.py +++ b/tests/test_infer/test_cuda_graph.py @@ -40,7 +40,7 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32): input_len = 1024 output_len = 128 - do_sample = True + do_sample = False top_p = 0.5 top_k = 50 diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 088b1f5aa..7125ca386 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -3,24 +3,27 @@ import random import numpy as np import pytest import torch +import torch.distributed as dist +from torch.multiprocessing import Manager from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM import colossalai from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig from colossalai.inference.core.engine import InferenceEngine -from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM +from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def setup_seed(seed): torch.manual_seed(seed) + torch.random.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) -def check_inference_engine(use_engine=False, prompt_template=None): +def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") model = LlamaForCausalLM( @@ -36,13 +39,19 @@ def check_inference_engine(use_engine=False, prompt_template=None): ] output_len = 38 - do_sample = True + do_sample = do_sample top_p = 0.5 top_k = 50 if use_engine: - inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32") - inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + inference_config = InferenceConfig( + max_output_len=output_len, + prompt_template=prompt_template, + dtype="fp32", + use_cuda_kernel=True, + tp_size=dist.get_world_size(), + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() @@ -69,20 +78,14 @@ def check_inference_engine(use_engine=False, prompt_template=None): return outputs -@parameterize("prompt_template", [None, "llama"]) -def check_output_consistency(prompt_template): - cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template) - transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template) +def run_engine(world_size, **kwargs): + manager = Manager() + result_list = manager.list([-1] * world_size) # Create a shared list - for s1, s2 in zip(cai_outputs, transformer_outputs): - assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}" - - # clear singleton flash decoding tensors - 
FDIntermTensors._instances = {} + spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs) + return result_list[0] -@parameterize("num_layers", [1]) -@parameterize("max_length", [100]) def check_spec_dec(num_layers, max_length): torch.manual_seed(123) @@ -152,16 +155,47 @@ def check_spec_dec(num_layers, max_length): assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") - check_output_consistency() - check_spec_dec() + + if ret: + ret[rank] = func_to_run(**kwargs) + else: + func_to_run(**kwargs) + + +@parameterize("prompt_template", [None, "llama"]) +@parameterize("do_sample", [False]) +def test_tp_engine(prompt_template, do_sample): + kwargs1 = { + "use_engine": True, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": NoPaddingLlamaModelInferPolicy(), + } + + kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None} + + colossal_tp_1_output = run_engine(1, **kwargs1) + colossal_tp_2_output = run_engine(2, **kwargs1) + transformer_tp_1_output = run_engine(1, **kwargs2) + + for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): + assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" + assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" + + +@parameterize("num_layers", [1]) +@parameterize("max_length", [100]) +def test_spec_dec(num_layers, max_length): + spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length) @pytest.mark.dist @rerun_if_address_is_in_use() def test_inference_engine(): - spawn(run_dist, 1) + test_tp_engine() + test_spec_dec() if __name__ == "__main__":
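
A minimal usage sketch of the new tensor-parallel inference path, mirroring the updated test above; it is not part of the patch itself. The checkpoint path and prompt are placeholders, and the exact generate()/GenerationConfig arguments may vary across ColossalAI versions.

# Hypothetical usage sketch (assumptions: a local HF-format Llama checkpoint at
# model_path; InferenceEngine.generate accepting a generation_config kwarg).
import colossalai
import torch.distributed as dist
from transformers import AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
from colossalai.testing import spawn


def run_tp_inference(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")

    model_path = "/path/to/llama/checkpoint"  # placeholder: HF-format Llama checkpoint directory
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    inference_config = InferenceConfig(
        max_output_len=64,
        dtype="fp16",
        use_cuda_kernel=True,
        tp_size=dist.get_world_size(),  # shard attention/MLP weights across all launched ranks
    )
    # Passing a path (instead of an nn.Module) makes the engine build the model from its
    # config, shard it with shardformer, then load weights through InferCheckpoint_io.
    engine = InferenceEngine(
        model_path,
        tokenizer,
        inference_config,
        verbose=True,
        model_policy=NoPaddingLlamaModelInferPolicy(),
    )

    engine.add_request(prompts=["Introduce some landmarks in Paris:"])
    outputs = engine.generate(generation_config=GenerationConfig(do_sample=False, max_new_tokens=64))
    if dist.get_rank() == 0:
        print(outputs)


# Launch two processes so the policy shards each Llama layer across TP=2 ranks.
spawn(run_tp_inference, 2)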