diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py
index e050dd71c..50806a14b 100644
--- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py
+++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py
@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
-        module.weight.data = nn.functional.normalize(module.weight)
-
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
-            **kwargs,
-        )
-
-
-class BaichuanWpackLinear1D_Col(Linear1D_Col):
-    @staticmethod
-    def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
-        in_features = module.in_features * 3
-        out_features = module.out_features // 3
-        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
-        module.bias = None
+        module.weight.data = nn.functional.normalize(
+            module.weight
+        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.

         return Linear1D_Col.from_native_module(
             module,
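# --- Illustrative aside (not part of the patch) ---------------------------------------
# A minimal sketch of what nn.functional.normalize does to the lm_head weight above.
# With its defaults (p=2, dim=1) it rescales each output row to unit L2 norm, matching
# the NormHead treatment Baichuan2 applies to its lm_head. Shapes are toy values picked
# for the example only.
import torch
import torch.nn as nn

weight = torch.randn(8, 4)  # stand-in for (vocab_size, hidden_size)
normalized = nn.functional.normalize(weight)
manual = weight / weight.norm(p=2, dim=1, keepdim=True)
assert torch.allclose(normalized, manual, atol=1e-6)
# ---------------------------------------------------------------------------------------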
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index f10ef6e3c..ba3e7b4e8 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -1,11 +1,11 @@
 # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
-import itertools
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup

+from colossalai.accelerator import get_accelerator
 from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
@@ -16,7 +16,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
-from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
+from colossalai.tensor.d_tensor import is_distributed_tensor

 inference_ops = InferenceOpsLoader().load()
 logger = get_dist_logger(__name__)
@@ -55,24 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
     def __init__(
         self,
         config,
-        attn_qproj_w: torch.Tensor = None,
-        attn_kproj_w: torch.Tensor = None,
-        attn_vproj_w: torch.Tensor = None,
+        W_pack: ParallelModule = None,
         attn_oproj: ParallelModule = None,
         num_heads: int = None,
         hidden_size: int = None,
         model_shard_infer_config: ModelShardInferenceConfig = None,
         process_group: ProcessGroup = None,
-        helper_layout: Layout = None,
     ):
         """This layer will replace the BaichuanAttention.

         Args:
             config (BaichuanConfig): Holding the Baichuan model config.
-            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
-            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
-            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
+            W_pack (ParallelModule, optional): The packed weight. Defaults to None.
+            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
         self.o_proj = attn_oproj
@@ -82,10 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.hidden_size = hidden_size
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
-        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
-        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
-
-        self.helper_layout = helper_layout
+        self.W_pack = W_pack
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
@@ -96,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule):
         if config.hidden_size == 5120:
             slopes_start = self.process_group.rank() * num_heads
             self.use_alibi_attn = True
-            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
-                slopes_start : slopes_start + num_heads
-            ].contiguous()
+            self.alibi_slopes = get_alibi_slopes(
+                config.num_attention_heads, device=get_accelerator().get_current_device()
+            )[slopes_start : slopes_start + num_heads].contiguous()
             self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
@@ -112,78 +104,22 @@ class NopadBaichuanAttention(ParallelModule):
         """
         config = module.config

-        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
-
-        attn_qproj_w = q_proj_w
-        attn_kproj_w = k_proj_w
-        attn_vproj_w = v_proj_w
+        W_pack = module.W_pack
         attn_oproj = module.o_proj

         model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
-        helper_layout = (
-            module.W_pack.weight.dist_layout
-        )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
-
         attn_layer = NopadBaichuanAttention(
             config=config,
-            attn_qproj_w=attn_qproj_w,
-            attn_kproj_w=attn_kproj_w,
-            attn_vproj_w=attn_vproj_w,
+            W_pack=W_pack,
             attn_oproj=attn_oproj,
             model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             process_group=process_group,
-            helper_layout=helper_layout,
         )

         return attn_layer

-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        qkv_w = state_dict[prefix + "W_pack.weight"]
-
-        in_features = qkv_w.size(1)
-        out_features = qkv_w.size(0) // 3
-
-        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
-
-        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
-
-        strict = False  # to avoid unexpected_keys
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,13 +156,13 @@ class NopadBaichuanAttention(ParallelModule):
             cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """
-        token_nums = hidden_states.size(0)
-        # fused qkv
-        hidden_states = hidden_states.expand(3, -1, -1)
-        query_states, key_states, value_states = (
-            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
-        )
+        token_nums = hidden_states.size(0)
+
+        proj = self.W_pack(hidden_states)
+        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+        query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
+        key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
+        value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)

         block_size = k_cache.size(-2)

@@ -279,9 +215,6 @@ class NopadBaichuanAttention(ParallelModule):

         return attn_output

-    def extra_repr(self) -> str:
-        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
-

 # NOTE This will cause difference as out length increases.
 class NopadBaichuanMLP(NopadLlamaMLP):
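# --- Illustrative aside (not part of the patch) ---------------------------------------
# A minimal sketch of the unflatten/unsqueeze/transpose/squeeze chain used in the new
# NopadBaichuanAttention.forward above: it splits the packed W_pack output into Q, K and
# V views without copying. Toy sizes stand in for token_nums and hidden_size.
import torch

token_nums, hidden_size = 5, 8
proj = torch.randn(token_nums, 3 * hidden_size)

unpacked = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
assert unpacked.shape == (3, token_nums, hidden_size)

# Equivalent, more explicit formulation: chunk the packed projection along the last dim.
q, k, v = torch.chunk(proj, 3, dim=-1)
assert torch.equal(unpacked[0], q) and torch.equal(unpacked[1], k) and torch.equal(unpacked[2], v)
# ---------------------------------------------------------------------------------------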
diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py
index b28c2fce8..37b5062e8 100644
--- a/colossalai/inference/modeling/policy/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -1,8 +1,5 @@
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.layers.baichuan_tp_linear import (
-    BaichuanLMHeadLinear1D_Col,
-    BaichuanWpackLinear1D_Col,
-)
+from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
 from colossalai.inference.modeling.models.nopadding_baichuan import (
     NopadBaichuanAttention,
     NopadBaichuanMLP,
@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack",
-                    target_module=BaichuanWpackLinear1D_Col,
+                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
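# --- Illustrative aside (not part of the patch) ---------------------------------------
# A minimal sketch of why self_attn.W_pack is replaced with a fused column-parallel
# linear (n_fused=3) rather than a plain Linear1D_Col. The packed weight stacks the Q, K
# and V projections along dim 0, so a naive split across tensor-parallel ranks would hand
# rank 0 only Q rows and the last rank only V rows; splitting each of the three blocks
# first gives every rank its own slice of Q, K and V. This is only a conceptual model of
# the sharding, not the FusedLinear1D_Col implementation; sizes are toy values.
import torch

hidden_size, tp_size, rank = 8, 2, 0
w_pack = torch.randn(3 * hidden_size, hidden_size)  # packed [Q; K; V] weight

rows_per_rank = hidden_size // tp_size
naive_shard = torch.chunk(w_pack, tp_size, dim=0)[rank]  # contains Q rows only (wrong heads)
fused_shard = torch.cat(
    [block[rank * rows_per_rank : (rank + 1) * rows_per_rank] for block in torch.chunk(w_pack, 3, dim=0)],
    dim=0,
)  # this rank's share of Q, K and V
assert naive_shard.shape == fused_shard.shape == (3 * hidden_size // tp_size, hidden_size)
# ---------------------------------------------------------------------------------------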