[Feat] Tensor Model Parallel Support For Inference (#5563)

* add naive source for tensor parallel support

* [fix] precision and model loading; refactor the framework

* add tp unit test

* docstring

* fix do_sample
Runyu Lu 2024-04-18 16:56:46 +08:00 committed by GitHub
parent be396ad6cc
commit e37ee2fb65
8 changed files with 640 additions and 150 deletions
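For orientation, here is a minimal end-to-end sketch of how the new tensor-parallel path is intended to be driven. It is assembled from the engine and test changes below; the checkpoint path, the torchrun launch, the fp16 dtype, and the final generate() call are assumptions for illustration, not part of this diff.

import colossalai
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy

# One process per TP rank, e.g. launched via `torchrun --nproc_per_node 2 run.py`.
colossalai.launch_from_torch(config={})

model_path = "/path/to/llama-checkpoint"  # placeholder; must contain a *.index.json shard index
tokenizer = AutoTokenizer.from_pretrained(model_path)

inference_config = InferenceConfig(
    max_output_len=128,
    dtype="fp16",
    use_cuda_kernel=True,
    tp_size=2,  # new in this PR: shard attention/MLP weights across 2 ranks
)
engine = InferenceEngine(
    model_path,  # model_or_path now also accepts a checkpoint path
    tokenizer,
    inference_config,
    verbose=True,
    model_policy=NoPaddingLlamaModelInferPolicy(),
)
engine.add_request(prompts=["Introduce some landmarks in Beijing"])
outputs = engine.generate()  # assumed API for draining the queued request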


@@ -5,8 +5,17 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
@@ -14,6 +23,8 @@ from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -25,10 +36,10 @@ __all__ = ["InferenceEngine"]
PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
"BaichuanForCausalLM",
]
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
@@ -39,7 +50,7 @@ class InferenceEngine:
InferenceEngine which manages the inference process.
Args:
model (nn.Module): Path or nn.Module of this model.
model_or_path (nn.Module or str): Path or nn.Module of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
@@ -48,53 +59,25 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.high_precision = inference_config.high_precision
self._verify_args()
self.generation_config = inference_config.to_generation_config(self.model_config)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
if model_policy is None:
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()
pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
self.model = self._shardformer(
model,
model_policy,
None,
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
)
self.verbose = verbose
if verbose:
self.logger = get_dist_logger(__name__)
self.logger = get_dist_logger(__name__)
self.init_model(model_or_path, model_policy)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
@@ -111,6 +94,91 @@ class InferenceEngine:
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self._verify_args()
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
"""
Shard the model and/or load its weights.
Args:
model_or_path (Union[nn.Module, str]): Path to the checkpoint or a model in transformers format.
model_policy (Policy): The policy used to shard/replace the model modules.
"""
if isinstance(model_or_path, str):
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0]
model = _supported_models[arch](hf_config)
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if isinstance(model_or_path, str):
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
@@ -194,8 +262,11 @@ class InferenceEngine:
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."
def _shardformer(
self,


@@ -0,0 +1,140 @@
import logging
import os
from functools import reduce
from pathlib import Path
from typing import Optional
import torch
from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class InferCheckpoint_io(GeneralCheckpointIO):
"""
This class is for inference model loading; most of the code is copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO.
The original HybridParallelCheckpointIO contains mixed-precision-training code, which is removed here to build a relatively clean class dedicated to inference.
"""
def __init__(
self,
verbose: bool = True,
) -> None:
super().__init__()
self.verbose = verbose
self.coordinator = DistCoordinator()
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False, since params on the same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model = model.unwrap()
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
missing_keys = []
missing_file_keys = []
def _load(name: str):
if name not in weight_map:
missing_file_keys.append(name)
return
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key)
if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
if len(missing_keys) == 0:
raise RuntimeError(
"No weigth is loaded into the model. Please check the checkpoint files and the model structure."
)
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
else:
if self.coordinator.is_master():
logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
raise NotImplementedError
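A short sketch of how the engine's init_model (in the first file above) drives this class; `sharded_model` and the checkpoint path are placeholders standing in for the ShardFormer output and a real sharded checkpoint.

from colossalai.inference.core.plugin import InferCheckpoint_io
from colossalai.inference.utils import has_index_file
from colossalai.interface import ModelWrapper

# `sharded_model` is assumed to be the nn.Module already sharded by ShardFormer.
wrapped = ModelWrapper(sharded_model).to("cuda")

has_index, index_file = has_index_file("/path/to/llama-checkpoint")  # placeholder path
assert has_index, "InferCheckpoint_io only loads sharded checkpoints that ship an *.index.json"

cpt_io = InferCheckpoint_io()
cpt_io.load_model(wrapped, index_file)  # dispatches to load_sharded_model defined above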


@@ -140,7 +140,7 @@ class RequestHandler:
fd_inter_tensor.initialize(
max_batch_size=max_n_tokens,
num_attn_heads=model_config.num_attention_heads,
num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
@@ -150,7 +150,7 @@
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -161,7 +161,7 @@
device=device,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
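The three changes above all divide the head count by the TP degree because each rank now only holds, and allocates KV cache for, its own shard of the attention heads. A quick worked example (32 heads is an assumed Llama-7B-like value):

num_attention_heads = 32  # assumed model config value
tp_size = 2
heads_per_rank = num_attention_heads // tp_size  # 16 KV-cache heads allocated per TP rank
assert num_attention_heads % tp_size == 0  # the TP degree must divide the head count evenly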


@@ -1,8 +1,11 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple
import itertools
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed import ProcessGroup
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -26,6 +29,8 @@ from colossalai.kernel.triton import (
rotary_embedding,
)
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
inference_ops = InferenceOpsLoader().load()
@@ -68,7 +73,8 @@ def llama_causal_lm_forward(
use_cuda_kernel=inputmetadata.use_cuda_kernel, # NOTE currently the cuda kernels for layernorm and rotary_embedding_and_cache_copy don't pass the unit test, but the triton kernels do
high_precision=inputmetadata.high_precision,
)
logits = torch.mm(hidden_states, self.lm_head.weight)
logits = self.lm_head(hidden_states)
return logits
@@ -109,6 +115,7 @@ def llama_model_forward(
logger.warning("CUDA kernel is disabled for speculative-decoding.")
hidden_states = self.embed_tokens(input_tokens_ids)
cu_seqlens = None
# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
@@ -126,7 +133,7 @@
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata != torch.float32 and use_flash_attn2:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
hidden_dim = self._cos_cached.size(-1)
@@ -270,7 +277,129 @@ def llama_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
class NopadLlamaAttention(LlamaAttention):
class NopadLlamaMLP(ParallelModule, LlamaMLP):
def __init__(
self,
config: LlamaConfig,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj: ParallelModule = None,
process_group: ProcessGroup = None,
):
"""A Unified Layer for
Args:
config (LlamaConfig): Holding the Llama model config.
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj (Linear1D_Row, optional): The row-parallel (Linear1D_Row) down_proj module. Defaults to None.
"""
ParallelModule.__init__(self)
self.config = config
assert is_distributed_tensor(
mlp_gproj_w
), "mlp_gproj_w must be dtensor so we could get the layout of the weight"
self.helper_layout = (
mlp_gproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of gate_up_weight(used in _load_from_state_dict)
self.gate_up_weight = nn.Parameter(
torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
)
self.down_proj = mlp_dproj
self.process_group = process_group
@staticmethod
def from_native_module(
module: LlamaMLP, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
Args:
module (LlamaMLP): The origin LlamaMLP layer.
"""
config = module.config
mlp_gproj_w = module.gate_proj.weight
assert is_distributed_tensor(
module.gate_proj.weight
), "gate_proj.weight must be dtensor so we could get the layout of the weight"
mlp_uproj_w = module.up_proj.weight
mlp_dproj = module.down_proj
mlp_layer = NopadLlamaMLP(
config=config,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj=mlp_dproj,
process_group=process_group,
)
return mlp_layer
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
# NOTE This is a hack to ensure we can load the right weights from a LlamaMLP checkpoint, since gate_proj and up_proj are fused via torch.stack(gate_weight, up_weight)
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
key = "gate_up_weight"
k1 = "gate_proj.weight"
k2 = "up_proj.weight"
gate_w = state_dict[prefix + k1]
up_w = state_dict[prefix + k2]
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
input_param = nn.Parameter(
gate_up_w
) # NOTE gate_up_weight doesn't have to be a dtensor, like input_param = sharded_tensor_to_param(input_param)
param = local_state[key]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return self.down_proj(act_out)
def extra_repr(self) -> str:
return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
class NopadLlamaAttention(ParallelModule, LlamaAttention):
def __init__(
self,
config: LlamaConfig,
@@ -278,7 +407,11 @@ class NopadLlamaAttention(LlamaAttention):
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
process_group: ProcessGroup = None,
num_heads: int = None,
hidden_size: int = None,
num_key_value_heads: int = None,
):
"""This layer will replace the LlamaAttention.
@@ -288,36 +421,54 @@ class NopadLlamaAttention(LlamaAttention):
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
attn_oproj (Linear1D_Row, optional): The row-parallel (Linear1D_Row) o_proj module. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj_weight = attn_qproj_w
self.k_proj_weight = attn_kproj_w
self.v_proj_weight = attn_vproj_w
self.o_proj_weight = attn_oproj_w
ParallelModule.__init__(self)
self.config = config
self.layer_idx = layer_idx
self.o_proj = attn_oproj
self.process_group = process_group
self.attention_dropout = config.attention_dropout
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.q_proj = None
self.k_proj = None
self.v_proj = None
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = (
attn_qproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
else:
self.q_proj_weight = attn_qproj_w
self.k_proj_weight = attn_kproj_w
self.v_proj_weight = attn_vproj_w
@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
def from_native_module(
module: LlamaAttention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.
Args:
module (LlamaAttention): The origin LlamaAttention layer.
"""
config = module.config
layer_idx = module.layer_idx
attn_qproj_w = module.q_proj.weight.transpose(0, 1)
attn_kproj_w = module.k_proj.weight.transpose(0, 1)
attn_vproj_w = module.v_proj.weight.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
attn_qproj_w = module.q_proj.weight
attn_kproj_w = module.k_proj.weight
attn_vproj_w = module.v_proj.weight
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
attn_oproj = module.o_proj
attn_layer = NopadLlamaAttention(
config=config,
@@ -325,7 +476,11 @@ class NopadLlamaAttention(LlamaAttention):
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
attn_oproj=attn_oproj,
process_group=process_group,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
)
return attn_layer
@@ -487,63 +642,57 @@ class NopadLlamaAttention(LlamaAttention):
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
attn_output = self.o_proj(attn_output)
return attn_output
# NOTE This will cause differences as the output length increases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,
config: LlamaConfig,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""This layer will replace the LlamaAttention.
# NOTE This is a hack to ensure we can load the right weights from a LlamaAttention checkpoint, since q/k/v are fused via torch.stack(q_weight, k_weight, v_weight)
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
Args:
config (LlamaConfig): Holding the Llama model config.
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w
self.gate_proj = None
self.up_proj = None
self.down_proj = None
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
key = "qkv_weight"
k1 = "q_proj.weight"
k2 = "k_proj.weight"
k3 = "v_proj.weight"
q_w = state_dict[prefix + k1]
k_w = state_dict[prefix + k2]
v_w = state_dict[prefix + k3]
Args:
module (LlamaMLP): The origin LlamaMLP layer.
"""
config = module.config
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
mlp_layer = NopadLlamaMLP(
config=config,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight doesn't have to be a dtensor, like input_param = sharded_tensor_to_param(input_param)
param = local_state[key]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
return mlp_layer
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"


@@ -1,4 +1,3 @@
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
from colossalai.inference.modeling.models.nopadding_llama import (
@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@@ -21,26 +21,69 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
}
policy[LlamaForCausalLM] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
else:
decoder_attribute_replacement = None
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadLlamaMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadLlamaAttention,
),
]
],
)
policy[LlamaForCausalLM] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
)
],
)
# self.shard_config._infer()
self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
)
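The sub-module replacements above follow the usual Megatron-style split: gate/up/q/k/v projections become column-parallel (Linear1D_Col) and down_proj/o_proj become row-parallel (Linear1D_Row), so a single all-reduce at the row-parallel output restores the full result. A tiny numerical illustration with assumed sizes and tp_size = 2 (the elementwise activation between the two GEMMs does not change this, since it acts per intermediate column):

import torch

tokens, hidden, inter = 4, 8, 16
x = torch.randn(tokens, hidden)
w1 = torch.randn(hidden, inter)  # gate/up analogue, stored as [in, out] for readability
w2 = torch.randn(inter, hidden)  # down_proj analogue

full = (x @ w1) @ w2

w1_shards = w1.chunk(2, dim=1)  # column-parallel: each rank keeps half of the output features
w2_shards = w2.chunk(2, dim=0)  # row-parallel: each rank keeps the matching half of the input features
partials = [(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
assert torch.allclose(sum(partials), full, atol=1e-4)  # this sum is what Linear1D_Row's all-reduce performs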


@@ -2,8 +2,12 @@
Utils for model inference
"""
import os
import re
from pathlib import Path
from typing import Optional, Tuple
import torch
from torch import nn
def init_to_get_rotary(self, base=10000, use_elem=False):
@@ -49,3 +53,52 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
"""
Check whether the checkpoint has an index file.
Args:
checkpoint_path (str): path to the checkpoint.
Returns:
Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
"""
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file():
# check if it is .index.json
reg = re.compile("(.*?).index((\..*)?).json")
if reg.fullmatch(checkpoint_path.name) is not None:
return True, checkpoint_path
else:
return False, None
elif checkpoint_path.is_dir():
index_files = list(checkpoint_path.glob("*.index.*json"))
for index_file in index_files:
if "safetensors" in index_file.__str__():
return True, index_file.__str__() # return the safetensors file first
if len(index_files) == 1:
return True, index_files[0]
else:
assert (
len(index_files) == 1
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
return False, None
else:
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def get_model_size(model: nn.Module):
"""Calculates the total size of the model weights (including biases) in bytes.
Args:
model: The PyTorch model to analyze.
Returns:
The total size of the model weights in bytes.
"""
total_size = 0
for key, param in model.named_parameters():
total_size += param.element_size() * param.numel()
return total_size / (1024**3)
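A small usage sketch for the two helpers added here; the toy module and the temporary directory are only for illustration.

import tempfile
from pathlib import Path

import torch.nn as nn

from colossalai.inference.utils import get_model_size, has_index_file

# get_model_size accumulates bytes internally but returns GB.
toy = nn.Linear(1024, 1024, bias=False)  # fp32: 1024 * 1024 * 4 bytes
print(f"toy model size: {get_model_size(toy):.6f} GB")

# has_index_file accepts either an .index.json file or a checkpoint directory.
with tempfile.TemporaryDirectory() as ckpt_dir:
    (Path(ckpt_dir) / "model.safetensors.index.json").touch()
    found, index_file = has_index_file(ckpt_dir)
    print(found, index_file)  # True, path of the safetensors index file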


@@ -40,7 +40,7 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32):
input_len = 1024
output_len = 128
do_sample = True
do_sample = False
top_p = 0.5
top_k = 50


@@ -3,24 +3,27 @@ import random
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, prompt_template=None):
def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
@@ -36,13 +39,19 @@ def check_inference_engine(use_engine=False, prompt_template=None):
]
output_len = 38
do_sample = True
do_sample = do_sample
top_p = 0.5
top_k = 50
if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
dtype="fp32",
use_cuda_kernel=True,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
@@ -69,20 +78,14 @@ def check_inference_engine(use_engine=False, prompt_template=None):
return outputs
@parameterize("prompt_template", [None, "llama"])
def check_output_consistency(prompt_template):
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
# clear singleton flash decoding tensors
FDIntermTensors._instances = {}
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
return result_list[0]
@parameterize("num_layers", [1])
@parameterize("max_length", [100])
def check_spec_dec(num_layers, max_length):
torch.manual_seed(123)
@@ -152,16 +155,47 @@ def check_spec_dec(num_layers, max_length):
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()
check_spec_dec()
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
@parameterize("prompt_template", [None, "llama"])
@parameterize("do_sample", [False])
def test_tp_engine(prompt_template, do_sample):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingLlamaModelInferPolicy(),
}
kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@parameterize("num_layers", [1])
@parameterize("max_length", [100])
def test_spec_dec(num_layers, max_length):
spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
spawn(run_dist, 1)
test_tp_engine()
test_spec_dec()
if __name__ == "__main__":