mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support chatglm without layernorm
parent cbb54d3202
commit dad00c42aa
@@ -396,17 +396,18 @@ class SelfAttention(torch.nn.Module):
             self.num_multi_query_groups_per_partition = config.multi_query_group_num
             self.qkv_hidden_size = (self.projection_size +
                                     2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
+<<<<<<< HEAD
         self.query_key_value = nn.Linear(
             config.hidden_size,
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
             **_config_to_kwargs(config),
         )
+=======
+        self.query_key_value = nn.Linear(self.hidden_size,
+                                         self.qkv_hidden_size,
+                                         bias=config.add_bias_linear or config.add_qkv_bias,
+                                         device=device,
+                                         **_config_to_kwargs(config))
+>>>>>>> [shardformer] support chatglm without layernorm

         self.core_attention = CoreAttention(config, self.layer_number)

         # Output.
+<<<<<<< HEAD
         self.dense = nn.Linear(
             self.projection_size,
             config.hidden_size,
@@ -414,6 +415,13 @@ class SelfAttention(torch.nn.Module):
             device=device,
             **_config_to_kwargs(config),
         )
+=======
+        self.dense = nn.Linear(self.projection_size,
+                               self.hidden_size,
+                               bias=config.add_bias_linear,
+                               device=device,
+                               **_config_to_kwargs(config))
+>>>>>>> [shardformer] support chatglm without layernorm

     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
         if self.multi_query_attention:
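For orientation (a sketch, not part of this commit): under multi-query attention the fused QKV projection in the hunks above is sized as projection_size for the queries plus 2 * hidden_size_per_attention_head * multi_query_group_num for the shared keys and values. The concrete numbers below are illustrative placeholders, not values read from this repository.

# Sketch only: reproduce the qkv_hidden_size formula from the hunk above.
# All concrete numbers are assumed for illustration.
projection_size = 4096                  # assumed: kv_channels * num_attention_heads
hidden_size_per_attention_head = 128    # assumed: projection_size // num_attention_heads
multi_query_group_num = 2               # assumed: number of shared key/value groups

# Queries keep one projection per attention head; keys and values are shared
# within each multi-query group, so they only need multi_query_group_num heads.
qkv_hidden_size = (projection_size +
                   2 * hidden_size_per_attention_head * multi_query_group_num)
print(qkv_hidden_size)  # 4096 + 2 * 128 * 2 = 4608

Both sides of the conflict build the same fused nn.Linear with this qkv_hidden_size as its output width; they differ only in formatting and in whether the input width is written as config.hidden_size or self.hidden_size.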
@@ -925,7 +933,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)
-            print(inputs_embeds)

         if self.pre_seq_len is not None:
             if past_key_values is None:
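The hunk above removes a leftover debug print after the embedding lookup in ChatGLMModel.forward. As a standalone illustration (a sketch under assumed names, not the ColossalAI code), the surrounding inputs_embeds fallback pattern works like this:

# Standalone sketch of the inputs_embeds fallback shown in the hunk above:
# embeddings are looked up from input_ids only when the caller did not pass
# precomputed inputs_embeds. TinyEmbeddingModel is a made-up stand-in.
import torch
import torch.nn as nn

class TinyEmbeddingModel(nn.Module):
    def __init__(self, vocab_size: int = 1000, hidden_size: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)
        return inputs_embeds

model = TinyEmbeddingModel()
hidden = model(input_ids=torch.tensor([[1, 2, 3]]))
assert hidden.shape == (1, 3, 16)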