[shardformer] support chatglm without layernorm

pull/4445/head
klhhhhh 2023-07-14 18:10:52 +08:00 committed by Hongxin Liu
parent cbb54d3202
commit dad00c42aa
1 changed file with 13 additions and 6 deletions


@@ -396,17 +396,18 @@ class SelfAttention(torch.nn.Module):
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (self.projection_size +
                                    2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
<<<<<<< HEAD
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            **_config_to_kwargs(config),
        )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
=======
        self.query_key_value = nn.Linear(self.hidden_size,
                                         self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
<<<<<<< HEAD
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
@@ -414,6 +415,13 @@ class SelfAttention(torch.nn.Module):
            device=device,
            **_config_to_kwargs(config),
        )
=======
        self.dense = nn.Linear(self.projection_size,
                               self.hidden_size,
                               bias=config.add_bias_linear,
                               device=device,
                               **_config_to_kwargs(config))
>>>>>>> [shardformer] support chatglm without layernorm

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        if self.multi_query_attention:
@@ -925,7 +933,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)
        print(inputs_embeds)
        if self.pre_seq_len is not None:
            if past_key_values is None:
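
Both versions of the attention code caught between the conflict markers in the first two hunks build the same pair of projection layers; they differ only in whether the width is read from config.hidden_size or from the self.hidden_size attribute set earlier in __init__, and in how the nn.Linear calls are wrapped. The sketch below is a minimal, self-contained reconstruction of those two layers, assuming the self.hidden_size variant is the intended resolution of the conflict; SimpleConfig, its example values, and the local _config_to_kwargs helper are illustrative stand-ins, not the definitions from the upstream modeling file.

# Minimal sketch of the resolved SelfAttention projection layers.
# SimpleConfig and this local _config_to_kwargs are stand-ins for the real
# ChatGLMConfig and modeling helper, included only so the snippet runs on its own.
import torch
import torch.nn as nn


class SimpleConfig:
    # Example values in the style of a ChatGLM2-6B configuration.
    hidden_size = 4096
    kv_channels = 128
    num_attention_heads = 32
    multi_query_attention = True
    multi_query_group_num = 2
    add_bias_linear = False
    add_qkv_bias = True
    torch_dtype = torch.float32


def _config_to_kwargs(config):
    # Stand-in helper: forward the model dtype to nn.Linear.
    return {"dtype": config.torch_dtype}


class SelfAttentionProjections(nn.Module):
    """QKV and output projections shaped as in the diff above, with the
    input/output widths taken from attributes on self."""

    def __init__(self, config, device=None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.multi_query_attention = config.multi_query_attention

        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (self.projection_size +
                                    2 * self.hidden_size_per_attention_head * config.multi_query_group_num)

        # Fused query/key/value projection.
        self.query_key_value = nn.Linear(self.hidden_size,
                                         self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device,
                                         **_config_to_kwargs(config))
        # Output projection back to the hidden width.
        self.dense = nn.Linear(self.projection_size,
                               self.hidden_size,
                               bias=config.add_bias_linear,
                               device=device,
                               **_config_to_kwargs(config))


if __name__ == "__main__":
    layer = SelfAttentionProjections(SimpleConfig())
    x = torch.randn(1, 4, SimpleConfig.hidden_size)
    print(layer.query_key_value(x).shape)  # torch.Size([1, 4, 4608]) for these example sizes

Reading the widths from attributes on self rather than straight from config would let a sharding pass adjust the per-rank projection sizes before the layers are built, but that motivation is inferred from the surrounding shardformer PR rather than stated in this commit.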