[example] change qkv processing (#870)

pull/877/head
LuGY 2022-04-26 13:33:27 +08:00 committed by GitHub
parent 96211c2cc8
commit 2883040286
1 changed file with 7 additions and 6 deletions

@@ -89,13 +89,14 @@ class GPTSelfAttention(nn.Module):
     def forward(self, x, attention_mask=None):
         qkv = self.query_key_value(x)
-        all_head_size = qkv.shape[-1] // 3
-        num_attention_heads = divide(all_head_size, self.attention_head_size)
-        new_qkv_shape = qkv.shape[:-1] + \
-            (num_attention_heads, 3 * self.attention_head_size)
-        qkv = qkv.view(new_qkv_shape)
-        qkv = qkv.permute((0, 2, 1, 3))
         q, k, v = torch.chunk(qkv, 3, dim=-1)
+        all_head_size = q.shape[-1]
+        num_attention_heads = divide(all_head_size, self.attention_head_size)
+        new_shape = q.shape[:-1] + \
+            (num_attention_heads, self.attention_head_size)
+        q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()
         x = torch.matmul(q, k.transpose(-1, -2))
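For context, a minimal standalone sketch of the chunk-first pattern this commit introduces, runnable outside the model. The split_qkv helper, the assertion standing in for ColossalAI's divide(), and the example shapes are illustrative assumptions, not part of the commit:

    import torch

    def split_qkv(qkv, attention_head_size):
        # qkv: (batch, seq, 3 * hidden) from a fused query-key-value projection.
        # Chunk first, then reshape each of q, k, v on its own.
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        all_head_size = q.shape[-1]
        assert all_head_size % attention_head_size == 0  # stands in for divide()
        num_attention_heads = all_head_size // attention_head_size
        new_shape = q.shape[:-1] + (num_attention_heads, attention_head_size)
        # (batch, seq, heads, head_size) -> (batch, heads, seq, head_size)
        q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
        k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
        v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()
        return q, k, v

    # Example: batch 2, seq 4, hidden 128 split into 8 heads of size 16.
    qkv = torch.randn(2, 4, 3 * 128)
    q, k, v = split_qkv(qkv, attention_head_size=16)
    print(q.shape)                                 # torch.Size([2, 8, 4, 16])
    scores = torch.matmul(q, k.transpose(-1, -2))  # (2, 8, 4, 4)

Chunking before the reshape lets q, k, and v each be made contiguous independently, instead of materializing one interleaved (batch, heads, seq, 3 * head_size) tensor and slicing it afterwards as the old code did.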