@@ -1,11 +1,11 @@
# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
@@ -16,7 +16,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
from colossalai.tensor.d_tensor import is_distributed_tensor

inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
@@ -55,24 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
    def __init__(
        self,
        config,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
        W_pack: ParallelModule = None,
        attn_oproj: ParallelModule = None,
        num_heads: int = None,
        hidden_size: int = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        process_group: ProcessGroup = None,
        helper_layout: Layout = None,
    ):
        """This layer will replace the BaichuanAttention.

        Args:
            config (BaichuanConfig): Holding the Baichuan model config.
            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
            W_pack (ParallelModule, optional): The packed q/k/v projection module. Defaults to None.
            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
        """
        ParallelModule.__init__(self)
        self.o_proj = attn_oproj
@@ -82,10 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.process_group = process_group
        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
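        # NOTE (illustrative, not part of the original file): the three projection weights are
        # stacked into a single [3, hidden_size, num_heads * head_dim] parameter (num_heads being
        # the per-rank head count under tensor parallelism, consistent with the bmm/view in
        # forward), so Q, K and V can be produced with one batched matmul instead of three
        # separate projections.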

        self.helper_layout = helper_layout
        self.W_pack = W_pack
        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
        self.attention_backend = get_attention_backend(model_shard_infer_config)
        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
@@ -96,9 +88,9 @@ class NopadBaichuanAttention(ParallelModule):
        if config.hidden_size == 5120:
            slopes_start = self.process_group.rank() * num_heads
            self.use_alibi_attn = True
            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                slopes_start : slopes_start + num_heads
            ].contiguous()
            self.alibi_slopes = get_alibi_slopes(
                config.num_attention_heads, device=get_accelerator().get_current_device()
            )[slopes_start : slopes_start + num_heads].contiguous()
            self.alibi_slopes = nn.Parameter(self.alibi_slopes)
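            # NOTE (illustrative sketch, not part of the original file): get_alibi_slopes is
            # assumed to follow the standard ALiBi recipe, i.e. a geometric sequence of per-head
            # slopes. For a power-of-two total head count n, that recipe is:
            #
            #     import math, torch
            #     def alibi_slopes_power_of_2(n: int) -> torch.Tensor:
            #         start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # equals 2 ** (-8 / n)
            #         return torch.tensor([start ** (i + 1) for i in range(n)])
            #
            # Each tensor-parallel rank then keeps only its contiguous num_heads-sized slice of
            # the full slope vector, which is what the slicing above does.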

    @staticmethod
@@ -112,78 +104,22 @@ class NopadBaichuanAttention(ParallelModule):
        """
        config = module.config
        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)

        attn_qproj_w = q_proj_w
        attn_kproj_w = k_proj_w
        attn_vproj_w = v_proj_w
        W_pack = module.W_pack
        attn_oproj = module.o_proj
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

        helper_layout = (
            module.W_pack.weight.dist_layout
        )  # NOTE this is a hack for the right load/shard of qkv_weight (used in _load_from_state_dict)
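        # NOTE (illustrative, not part of the original file): helper_layout keeps the device mesh
        # and sharding spec attached to W_pack.weight (presumably applied upstream, e.g. by
        # shardformer), so that _load_from_state_dict below can shard the full checkpoint tensor
        # the same way via distribute_tensor before copying it into the stacked qkv_weight.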

        attn_layer = NopadBaichuanAttention(
            config=config,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
            W_pack=W_pack,
            attn_oproj=attn_oproj,
            model_shard_infer_config=model_shard_infer_config,
            num_heads=module.num_heads,
            hidden_size=module.hidden_size,
            process_group=process_group,
            helper_layout=helper_layout,
        )

        return attn_layer

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        key = "qkv_weight"
        qkv_w = state_dict[prefix + "W_pack.weight"]

        in_features = qkv_w.size(1)
        out_features = qkv_w.size(0) // 3

        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
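        # NOTE (illustrative, not part of the original file): shape walkthrough of the repack
        # above, ignoring tensor-parallel sharding:
        #   checkpoint W_pack.weight       -> [3 * out_features, in_features]
        #   .view((3, out_features, -1))   -> [3, out_features, in_features]
        #   .transpose(0, 1)               -> [out_features, 3, in_features]
        #   .reshape(out_features, in*3)   -> [out_features, 3 * in_features]
        # distribute_tensor below then shards this tensor according to helper_layout, and the
        # final transpose/reshape turns the local shard into the [3, in_features, -1] layout of
        # self.qkv_weight.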

        device_mesh = self.helper_layout.device_mesh
        sharding_spec = self.helper_layout.sharding_spec
        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)

        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
        input_param = nn.Parameter(
            qkv_w
        )  # NOTE qkv_weight doesn't have to be a distributed tensor, so no conversion like input_param = sharded_tensor_to_param(input_param) is needed

        param = local_state[key]

        try:
            with torch.no_grad():
                param.copy_(input_param)
        except Exception as ex:
            error_msgs.append(
                'While copying the parameter named "{}", '
                "whose dimensions in the model are {} and "
                "whose dimensions in the checkpoint are {}, "
                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
            )

        strict = False  # to avoid unexpected_keys
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -220,13 +156,13 @@ class NopadBaichuanAttention(ParallelModule):
            cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths.
            high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        """

        token_nums = hidden_states.size(0)
        # fused qkv
        hidden_states = hidden_states.expand(3, -1, -1)
        query_states, key_states, value_states = (
            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
        )

        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
        query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
        key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
        value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)
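        # NOTE (illustrative, not part of the original file): shape flow of the two qkv paths
        # above, with T = token_nums and H = self.hidden_size (per tensor-parallel rank where
        # applicable):
        #   bmm path   : [T, H] -> expand -> [3, T, H] -> bmm with qkv_weight [3, H, H']
        #                -> [3, T, H'] -> view(3, T, num_heads, head_dim) -> unbind(0)
        #   W_pack path: [T, H] -> W_pack -> [T, 3*H] -> unflatten(-1, (3, H)) -> [T, 3, H]
        #                -> unsqueeze(0) -> [1, T, 3, H] -> transpose(0, -2) -> [3, T, 1, H]
        #                -> squeeze(-2) -> [3, T, H] -> per-slice view(T, num_heads, head_dim)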

        block_size = k_cache.size(-2)
@@ -279,9 +215,6 @@ class NopadBaichuanAttention(ParallelModule):

        return attn_output

    def extra_repr(self) -> str:
        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"


# NOTE This will cause differences as the output length increases.
class NopadBaichuanMLP(NopadLlamaMLP):