mirror of https://github.com/hpcaitech/ColossalAI

[Inference] refactor baichuan (#5791)

* refactor baichuan
* remove unused code and add TODO for lazyinit

pull/5793/head
parent 77a219a082
commit c0948aff97

@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
-        module.weight.data = nn.functional.normalize(module.weight)
-
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
-            **kwargs,
-        )
-
-
-class BaichuanWpackLinear1D_Col(Linear1D_Col):
-    @staticmethod
-    def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
-        in_features = module.in_features * 3
-        out_features = module.out_features // 3
-        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
-        module.bias = None
+        module.weight.data = nn.functional.normalize(
+            module.weight
+        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
 
         return Linear1D_Col.from_native_module(
             module,

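Note on the hunk above: Baichuan2 normalizes the lm_head weight (its NormHead) before the final projection, so the refactored `from_native_module` L2-normalizes `module.weight` in place and only then delegates the usual column-parallel split to `Linear1D_Col.from_native_module`; the TODO flags that under lazy init the weight seen here may not be the materialized one yet. A minimal sketch of the normalize-then-project idea in plain PyTorch (illustrative sizes, no tensor parallelism, not the ColossalAI API):

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 8, 32
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# what the hunk does in place before sharding: L2-normalize each output row
lm_head.weight.data = nn.functional.normalize(lm_head.weight)

hidden = torch.randn(4, hidden_size)
logits = lm_head(hidden)  # (4, vocab_size), computed with the normalized head

# every row of the normalized head has unit norm
assert torch.allclose(lm_head.weight.norm(dim=-1), torch.ones(vocab_size), atol=1e-5)
```
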
@@ -1,11 +1,11 @@
 # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
-import itertools
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 
+from colossalai.accelerator import get_accelerator
 from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend

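The only addition to the imports is `get_accelerator` (`itertools` is dropped along with the hand-written `_load_from_state_dict` further down). It is used later in this file to place the ALiBi slopes on the currently active device instead of borrowing `attn_qproj_w.device`, which no longer exists after the refactor. Typical usage, sketched:

```python
import torch
from colossalai.accelerator import get_accelerator

# the active device of the detected accelerator backend, e.g. cuda:0 on a CUDA box
device = get_accelerator().get_current_device()
slopes = torch.empty(8, device=device)
```
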
@@ -16,7 +16,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
-from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
+from colossalai.tensor.d_tensor import is_distributed_tensor
 
 inference_ops = InferenceOpsLoader().load()
 logger = get_dist_logger(__name__)

@@ -55,24 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
     def __init__(
         self,
         config,
-        attn_qproj_w: torch.Tensor = None,
-        attn_kproj_w: torch.Tensor = None,
-        attn_vproj_w: torch.Tensor = None,
+        W_pack: ParallelModule = None,
         attn_oproj: ParallelModule = None,
         num_heads: int = None,
         hidden_size: int = None,
         model_shard_infer_config: ModelShardInferenceConfig = None,
         process_group: ProcessGroup = None,
-        helper_layout: Layout = None,
     ):
         """This layer will replace the BaichuanAttention.
 
         Args:
             config (BaichuanConfig): Holding the Baichuan model config.
-            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
-            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
-            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
+            W_pack (ParallelModule, optional): The packed weight. Defaults to None.
+            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
         self.o_proj = attn_oproj

@@ -82,10 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.hidden_size = hidden_size
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
-        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
-        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
-
-        self.helper_layout = helper_layout
+        self.W_pack = W_pack
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

@@ -96,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule):
         if config.hidden_size == 5120:
             slopes_start = self.process_group.rank() * num_heads
             self.use_alibi_attn = True
-            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
-                slopes_start : slopes_start + num_heads
-            ].contiguous()
+            self.alibi_slopes = get_alibi_slopes(
+                config.num_attention_heads, device=get_accelerator().get_current_device()
+            )[slopes_start : slopes_start + num_heads].contiguous()
             self.alibi_slopes = nn.Parameter(self.alibi_slopes)
 
     @staticmethod

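This hunk only changes where the slopes tensor is materialized: on the accelerator's current device rather than on the device of the now-removed `attn_qproj_w`. For context, below is a hedged sketch of the standard ALiBi slope schedule that a helper like `get_alibi_slopes` typically computes, plus the per-tensor-parallel-rank slice taken above; the helper shown here is illustrative and not ColossalAI's implementation:

```python
import math
import torch

def alibi_slopes_sketch(num_heads: int, device=None) -> torch.Tensor:
    # standard ALiBi schedule: a geometric sequence with ratio 2^(-8/n); for head
    # counts that are not a power of two, interleave a second sequence
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-8.0 / closest_pow2)
    slopes = torch.pow(base, torch.arange(1, closest_pow2 + 1, device=device, dtype=torch.float32))
    if closest_pow2 != num_heads:
        extra_base = 2 ** (-4.0 / closest_pow2)
        num_extra = num_heads - closest_pow2
        extra = torch.pow(extra_base, torch.arange(1, 2 * num_extra + 1, 2, device=device, dtype=torch.float32))
        slopes = torch.cat([slopes, extra])
    return slopes

# each tensor-parallel rank keeps only the slopes of its own heads, as in the hunk above
tp_rank, heads_per_rank = 1, 20  # e.g. 40 attention heads split over 2 GPUs
slopes_start = tp_rank * heads_per_rank
alibi_slopes = alibi_slopes_sketch(40)[slopes_start : slopes_start + heads_per_rank].contiguous()
assert alibi_slopes.numel() == heads_per_rank
```
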
@@ -112,78 +104,22 @@ class NopadBaichuanAttention(ParallelModule):
         """
 
         config = module.config
-        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
-
-        attn_qproj_w = q_proj_w
-        attn_kproj_w = k_proj_w
-        attn_vproj_w = v_proj_w
+        W_pack = module.W_pack
         attn_oproj = module.o_proj
         model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
-
-        helper_layout = (
-            module.W_pack.weight.dist_layout
-        )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
 
         attn_layer = NopadBaichuanAttention(
             config=config,
-            attn_qproj_w=attn_qproj_w,
-            attn_kproj_w=attn_kproj_w,
-            attn_vproj_w=attn_vproj_w,
+            W_pack=W_pack,
             attn_oproj=attn_oproj,
             model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             process_group=process_group,
-            helper_layout=helper_layout,
         )
 
         return attn_layer
 
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        qkv_w = state_dict[prefix + "W_pack.weight"]
-
-        in_features = qkv_w.size(1)
-        out_features = qkv_w.size(0) // 3
-
-        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
-
-        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
-
-        strict = False  # to avoid unexpected_keys
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,

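The bulk of this hunk deletes the hand-written `_load_from_state_dict`: it was only needed because the old layer flattened W_pack into a stacked `qkv_weight` parameter whose name and layout no longer matched the checkpoint, which also required the `helper_layout` hack. With `W_pack` kept as a regular (sharded) child module, checkpoint weights flow through PyTorch's normal `load_state_dict` path. A toy illustration of that principle, with hypothetical module names and no tensor parallelism:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.W_pack = nn.Linear(hidden, 3 * hidden, bias=False)  # kept as a child module
        self.o_proj = nn.Linear(hidden, hidden, bias=False)

hidden = 8
src = ToyAttention(hidden)  # pretend this holds the checkpoint weights
dst = ToyAttention(hidden)  # freshly-initialized replacement layer

# parameter names line up ("W_pack.weight", "o_proj.weight"), so no custom hook is needed
missing, unexpected = dst.load_state_dict(src.state_dict())
assert not missing and not unexpected
assert torch.equal(dst.W_pack.weight, src.W_pack.weight)
```
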
@@ -220,13 +156,13 @@ class NopadBaichuanAttention(ParallelModule):
             cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """
 
         token_nums = hidden_states.size(0)
-        # fused qkv
-        hidden_states = hidden_states.expand(3, -1, -1)
-        query_states, key_states, value_states = (
-            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
-        )
+        proj = self.W_pack(hidden_states)
+        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+        query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
+        key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
+        value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)
 
         block_size = k_cache.size(-2)
 

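The rewritten q/k/v computation above is just a reshuffle of the packed projection produced by `self.W_pack(hidden_states)`. A quick standalone check (single-device view, illustrative sizes) that the unflatten/unsqueeze/transpose/squeeze chain is equivalent to a plain 3-way chunk of the last dimension:

```python
import torch

token_nums, num_heads, head_dim = 5, 4, 8
hidden_size = num_heads * head_dim
proj = torch.randn(token_nums, 3 * hidden_size)  # stands in for self.W_pack(hidden_states)

fancy = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
plain_q, plain_k, plain_v = proj.chunk(3, dim=-1)

assert torch.equal(fancy[0], plain_q)
assert torch.equal(fancy[1], plain_k)
assert torch.equal(fancy[2], plain_v)

query_states = fancy[0].view(token_nums, num_heads, head_dim)  # as in the hunk above
```
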
@@ -279,9 +215,6 @@ class NopadBaichuanAttention(ParallelModule):
 
         return attn_output
 
-    def extra_repr(self) -> str:
-        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
-
 
 # NOTE This will cause difference as out length increases.
 class NopadBaichuanMLP(NopadLlamaMLP):

@@ -1,8 +1,5 @@
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.layers.baichuan_tp_linear import (
-    BaichuanLMHeadLinear1D_Col,
-    BaichuanWpackLinear1D_Col,
-)
+from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
 from colossalai.inference.modeling.models.nopadding_baichuan import (
     NopadBaichuanAttention,
     NopadBaichuanMLP,

@@ -14,7 +11,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 

@@ -60,8 +57,7 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack",
-                    target_module=BaichuanWpackLinear1D_Col,
+                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",

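The policy now replaces `self_attn.W_pack` with the generic `FusedLinear1D_Col` (passing `kwargs={"n_fused": 3}`) instead of the bespoke `BaichuanWpackLinear1D_Col` deleted in the first hunk. The point of a fused column-parallel layer is that each tensor-parallel rank receives its slice of q, k and v rather than a contiguous third of the packed weight. A rough sketch of that sharding idea with plain slicing (this mimics the intent, not ColossalAI's actual implementation):

```python
import torch

hidden, tp_size = 4, 2
w_pack = torch.arange(3 * hidden * hidden, dtype=torch.float32).view(3 * hidden, hidden)

def fused_col_shard(weight: torch.Tensor, rank: int, tp_size: int, n_fused: int = 3) -> torch.Tensor:
    # split each fused segment (q, k, v) along the output dim, then take this rank's
    # slice of every segment, so each rank ends up with q/k/v rows for its own heads
    segments = weight.chunk(n_fused, dim=0)
    return torch.cat([seg.chunk(tp_size, dim=0)[rank] for seg in segments], dim=0)

shard0 = fused_col_shard(w_pack, rank=0, tp_size=tp_size)  # rows 0-1 (q), 4-5 (k), 8-9 (v)
shard1 = fused_col_shard(w_pack, rank=1, tp_size=tp_size)
assert shard0.shape == (3 * hidden // tp_size, hidden)
```
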