[Inference] refactor baichuan (#5791)

* refactor baichuan

* remove unused code and add TODO for lazyinit
Runyu Lu 2024-06-11 10:52:01 +08:00 committed by GitHub
parent 77a219a082
commit c0948aff97
3 changed files with 24 additions and 110 deletions

View File

@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
-        module.weight.data = nn.functional.normalize(module.weight)
-
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
-            **kwargs,
-        )
-
-
-class BaichuanWpackLinear1D_Col(Linear1D_Col):
-    @staticmethod
-    def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
-        in_features = module.in_features * 3
-        out_features = module.out_features // 3
-        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
-        module.bias = None
+        module.weight.data = nn.functional.normalize(
+            module.weight
+        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
 
         return Linear1D_Col.from_native_module(
             module,
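
The TODO above points at the lazy-init case, where normalizing the materialized weight up front is not enough. A minimal sketch of the kind of override it suggests, normalizing the checkpoint tensor at load time instead of mutating module.weight; the class name is illustrative and not part of the repo:

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col


class _NormalizingLMHead(Linear1D_Col):  # hypothetical name, for illustration only
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        key = prefix + "weight"
        if key in state_dict:
            # Fold Baichuan2's NormHead row-normalization into the checkpoint tensor at load time,
            # so it also applies when the module itself was constructed under lazy init.
            state_dict[key] = nn.functional.normalize(state_dict[key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )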

View File

@@ -1,11 +1,11 @@
 # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
-import itertools
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 
+from colossalai.accelerator import get_accelerator
 from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
@@ -16,7 +16,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
-from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
+from colossalai.tensor.d_tensor import is_distributed_tensor
 
 inference_ops = InferenceOpsLoader().load()
 logger = get_dist_logger(__name__)
@@ -55,24 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
     def __init__(
         self,
         config,
-        attn_qproj_w: torch.Tensor = None,
-        attn_kproj_w: torch.Tensor = None,
-        attn_vproj_w: torch.Tensor = None,
+        W_pack: ParallelModule = None,
         attn_oproj: ParallelModule = None,
         num_heads: int = None,
         hidden_size: int = None,
         model_shard_infer_config: ModelShardInferenceConfig = None,
         process_group: ProcessGroup = None,
-        helper_layout: Layout = None,
     ):
         """This layer will replace the BaichuanAttention.
 
         Args:
             config (BaichuanConfig): Holding the Baichuan model config.
-            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
-            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
-            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
+            W_pack (ParallelModule, optional): The packed weight. Defaults to None.
+            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
         self.o_proj = attn_oproj
@@ -82,10 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.hidden_size = hidden_size
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
-        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
-        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
-        self.helper_layout = helper_layout
-
+        self.W_pack = W_pack
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
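
Replacing the stacked qkv_weight parameter with a single W_pack submodule assumes the packed (3 * hidden_size, hidden_size) weight is sharded so that every tensor-parallel rank keeps matching q/k/v row blocks, which is what n_fused=3 expresses in the policy change below. A standalone sketch of that split, purely illustrative and not the shardformer implementation:

import torch


def shard_packed_weight(w_pack: torch.Tensor, tp_size: int, rank: int, n_fused: int = 3) -> torch.Tensor:
    # Split a packed (n_fused * H, H) weight so each rank keeps the matching
    # q/k/v row blocks rather than an arbitrary contiguous third of the rows.
    out_features, in_features = w_pack.shape
    per_group = out_features // n_fused      # rows belonging to each of q, k, v
    per_rank = per_group // tp_size          # rows of each group owned by one rank
    groups = w_pack.view(n_fused, per_group, in_features)
    shard = groups[:, rank * per_rank : (rank + 1) * per_rank, :]
    return shard.reshape(n_fused * per_rank, in_features)


# Quick check: the rank shards together reconstruct the original packed weight.
w = torch.randn(24, 8)                       # (3 * hidden, hidden) with hidden = 8
shards = [shard_packed_weight(w, tp_size=2, rank=r) for r in range(2)]
rebuilt = torch.cat([s.view(3, -1, 8) for s in shards], dim=1).reshape(24, 8)
assert torch.equal(rebuilt, w)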
@@ -96,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule):
         if config.hidden_size == 5120:
             slopes_start = self.process_group.rank() * num_heads
             self.use_alibi_attn = True
-            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
-                slopes_start : slopes_start + num_heads
-            ].contiguous()
+            self.alibi_slopes = get_alibi_slopes(
+                config.num_attention_heads, device=get_accelerator().get_current_device()
+            )[slopes_start : slopes_start + num_heads].contiguous()
             self.alibi_slopes = nn.Parameter(self.alibi_slopes)
 
     @staticmethod
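
For context, get_alibi_slopes is assumed here to follow the standard ALiBi slope recipe; the sketch below shows that recipe and the contiguous per-rank slicing used above (the hidden_size == 5120 case, Baichuan2-13B, uses 40 attention heads). This is a reference sketch, not the repo's implementation:

import math

import torch


def alibi_slopes_reference(num_heads: int) -> torch.Tensor:
    # Standard ALiBi recipe: a geometric sequence 2^(-8/n), 2^(-16/n), ...;
    # non power-of-two head counts reuse every other slope of the next power of two.
    def power_of_2(n):
        start = 2.0 ** (-8.0 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return torch.tensor(power_of_2(num_heads))
    closest = 2 ** math.floor(math.log2(num_heads))
    extra = power_of_2(2 * closest)[0::2][: num_heads - closest]
    return torch.tensor(power_of_2(closest) + extra)


# Each tensor-parallel rank keeps a contiguous block of slopes for its local heads,
# mirroring `slopes_start = self.process_group.rank() * num_heads` above.
total_heads, tp_size = 40, 2
local_heads = total_heads // tp_size
slopes = alibi_slopes_reference(total_heads)
rank_slopes = [slopes[r * local_heads : (r + 1) * local_heads] for r in range(tp_size)]
assert sum(s.numel() for s in rank_slopes) == total_heads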
@@ -112,78 +104,22 @@ class NopadBaichuanAttention(ParallelModule):
         """
         config = module.config
 
-        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
-
-        attn_qproj_w = q_proj_w
-        attn_kproj_w = k_proj_w
-        attn_vproj_w = v_proj_w
+        W_pack = module.W_pack
         attn_oproj = module.o_proj
 
         model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
-        helper_layout = (
-            module.W_pack.weight.dist_layout
-        )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
 
         attn_layer = NopadBaichuanAttention(
             config=config,
-            attn_qproj_w=attn_qproj_w,
-            attn_kproj_w=attn_kproj_w,
-            attn_vproj_w=attn_vproj_w,
+            W_pack=W_pack,
             attn_oproj=attn_oproj,
             model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             process_group=process_group,
-            helper_layout=helper_layout,
         )
 
         return attn_layer
 
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        qkv_w = state_dict[prefix + "W_pack.weight"]
-
-        in_features = qkv_w.size(1)
-        out_features = qkv_w.size(0) // 3
-
-        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
-
-        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
-
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
-
-        strict = False  # to avoid unexpected_keys
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
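
Since W_pack now stays a named submodule, checkpoint keys such as W_pack.weight map onto it through the default state-dict machinery, which is presumably why the custom _load_from_state_dict above could be removed. A tiny standalone round-trip check; TinyAttn is a hypothetical stand-in, not a class from the repo:

import torch
import torch.nn as nn


class TinyAttn(nn.Module):  # stand-in for an attention block that keeps W_pack as a submodule
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.W_pack = nn.Linear(hidden, 3 * hidden, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)


src, dst = TinyAttn(), TinyAttn()
dst.load_state_dict(src.state_dict())  # keys like "W_pack.weight" map 1:1, no custom loader needed
assert torch.equal(dst.W_pack.weight, src.W_pack.weight)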
@@ -220,13 +156,13 @@ class NopadBaichuanAttention(ParallelModule):
             cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """
         token_nums = hidden_states.size(0)
-
-        hidden_states = hidden_states.expand(3, -1, -1)
-        query_states, key_states, value_states = (
-            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
-        )
-
+        # fused qkv
+        proj = self.W_pack(hidden_states)
+        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+        query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
+        key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
+        value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)
 
         block_size = k_cache.size(-2)
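
The unflatten/unsqueeze/transpose/squeeze chain above only reorders the packed projection into a (3, token_nums, hidden_size) tensor before the per-head view. A quick standalone shape check, independent of the repo:

import torch

token_nums, hidden_size, num_heads = 5, 16, 4
head_dim = hidden_size // num_heads
proj = torch.randn(token_nums, 3 * hidden_size)  # stand-in for W_pack(hidden_states)

# The chain used in the forward pass above ...
chained = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
# ... is just a reshape-and-permute into (3, token_nums, hidden_size).
simple = proj.view(token_nums, 3, hidden_size).permute(1, 0, 2)
assert chained.shape == (3, token_nums, hidden_size)
assert torch.equal(chained, simple)

query, key, value = (t.view(token_nums, num_heads, head_dim) for t in chained.unbind(0))
assert query.shape == (token_nums, num_heads, head_dim)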
@@ -279,9 +215,6 @@ class NopadBaichuanAttention(ParallelModule):
 
         return attn_output
 
-    def extra_repr(self) -> str:
-        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
-
 
 # NOTE This will cause difference as out length increases.
 class NopadBaichuanMLP(NopadLlamaMLP):

View File

@@ -1,8 +1,5 @@
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.layers.baichuan_tp_linear import (
-    BaichuanLMHeadLinear1D_Col,
-    BaichuanWpackLinear1D_Col,
-)
+from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
 from colossalai.inference.modeling.models.nopadding_baichuan import (
     NopadBaichuanAttention,
     NopadBaichuanMLP,
@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack",
-                    target_module=BaichuanWpackLinear1D_Col,
+                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",