[shardformer] pre-commit check files

pull/4445/head
klhhhhh 2023-07-19 11:39:59 +08:00 committed by Hongxin Liu
parent 91850fe984
commit 4da05052f4
1 changed files with 6 additions and 13 deletions

View File

@ -396,18 +396,17 @@ class SelfAttention(torch.nn.Module):
self.num_multi_query_groups_per_partition = config.multi_query_group_num
self.qkv_hidden_size = (self.projection_size +
2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
<<<<<<< HEAD
self.query_key_value = nn.Linear(
config.hidden_size,
self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device,
**_config_to_kwargs(config),
)
=======
self.query_key_value = nn.Linear(self.hidden_size,
self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
<<<<<<< HEAD
self.core_attention = CoreAttention(config, self.layer_number)
# Output.
self.dense = nn.Linear(
self.projection_size,
config.hidden_size,
@ -415,13 +414,6 @@ class SelfAttention(torch.nn.Module):
device=device,
**_config_to_kwargs(config),
)
=======
self.dense = nn.Linear(self.projection_size,
config.hidden_size,
bias=config.add_bias_linear,
device=device,
**_config_to_kwargs(config))
>>>>>>> [shardformer] support chatglm without layernorm
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
if self.multi_query_attention:
@ -989,6 +981,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
def quantize(self, weight_bit_width: int):
from .quantization import quantize
quantize(self.encoder, weight_bit_width)
return self