mirror of https://github.com/hpcaitech/ColossalAI
[example] change qkv processing (#870)
parent 96211c2cc8
commit 2883040286
@@ -89,13 +89,14 @@ class GPTSelfAttention(nn.Module):
     def forward(self, x, attention_mask=None):
         qkv = self.query_key_value(x)
-        all_head_size = qkv.shape[-1] // 3
-        num_attention_heads = divide(all_head_size, self.attention_head_size)
-        new_qkv_shape = qkv.shape[:-1] + \
-            (num_attention_heads, 3 * self.attention_head_size)
-        qkv = qkv.view(new_qkv_shape)
-        qkv = qkv.permute((0, 2, 1, 3))
         q, k, v = torch.chunk(qkv, 3, dim=-1)
+        all_head_size = q.shape[-1]
+        num_attention_heads = divide(all_head_size, self.attention_head_size)
+        new_shape = q.shape[:-1] + \
+            (num_attention_heads, self.attention_head_size)
+        q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()
 
         x = torch.matmul(q, k.transpose(-1, -2))
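The change splits the fused qkv projection along its last dimension first, then reshapes and transposes each of q, k and v into per-head layout, with .contiguous() producing densely laid-out tensors for the attention matmul. Below is a minimal standalone sketch of the new shape flow; the sizes and the plain nn.Linear standing in for self.query_key_value are illustrative assumptions, not taken from this commit.

import torch

# Illustrative sizes only, not taken from the ColossalAI example.
batch, seq_len, num_heads, head_size = 2, 16, 4, 8
hidden_size = num_heads * head_size

# A plain nn.Linear standing in for self.query_key_value (assumption):
# one fused projection producing q, k and v side by side.
query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size)

x = torch.randn(batch, seq_len, hidden_size)
qkv = query_key_value(x)                        # (batch, seq, 3 * hidden)

# New order: split the fused projection first...
q, k, v = torch.chunk(qkv, 3, dim=-1)           # each (batch, seq, hidden)

# ...then reshape and transpose each tensor into per-head layout.
new_shape = q.shape[:-1] + (num_heads, head_size)
q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()    # (batch, heads, seq, head_size)
k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()

scores = torch.matmul(q, k.transpose(-1, -2))   # (batch, heads, seq, seq)
assert scores.shape == (batch, num_heads, seq_len, seq_len)

Chunking before the reshape also lets q, k and v share a single new_shape, rather than carving a (..., heads, 3 * head_size) tensor apart after the transpose as the removed lines did.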