mirror of https://github.com/hpcaitech/ColossalAI
[example] change qkv processing (#870)
parent 96211c2cc8
commit 2883040286
@@ -89,13 +89,14 @@ class GPTSelfAttention(nn.Module):

     def forward(self, x, attention_mask=None):
         qkv = self.query_key_value(x)
-        all_head_size = qkv.shape[-1] // 3
-        num_attention_heads = divide(all_head_size, self.attention_head_size)
-        new_qkv_shape = qkv.shape[:-1] + \
-            (num_attention_heads, 3 * self.attention_head_size)
-        qkv = qkv.view(new_qkv_shape)
-        qkv = qkv.permute((0, 2, 1, 3))
         q, k, v = torch.chunk(qkv, 3, dim=-1)
+        all_head_size = q.shape[-1]
+        num_attention_heads = divide(all_head_size, self.attention_head_size)
+        new_shape = q.shape[:-1] + \
+            (num_attention_heads, self.attention_head_size)
+        q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()
         x = torch.matmul(q, k.transpose(-1, -2))
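In short, the commit moves the q/k/v split ahead of the per-head reshape: qkv is now chunked along the hidden dimension first, and each chunk is then viewed as (batch, heads, seq, head_size) separately. Below is a minimal standalone sketch of the new flow, with made-up sizes (batch=2, seq=16, hidden=64, head_size=8) and a plain nn.Linear standing in for self.query_key_value; divide() from the original code is replaced by integer division here.

import torch
import torch.nn as nn

batch, seq, hidden, head_size = 2, 16, 64, 8      # made-up sizes for illustration
query_key_value = nn.Linear(hidden, 3 * hidden)   # stand-in for self.query_key_value

x = torch.randn(batch, seq, hidden)
qkv = query_key_value(x)                          # (2, 16, 192)

# New order: split the fused projection into q, k, v first...
q, k, v = torch.chunk(qkv, 3, dim=-1)             # each (2, 16, 64)

# ...then reshape each chunk to (batch, heads, seq, head_size).
all_head_size = q.shape[-1]
num_attention_heads = all_head_size // head_size  # divide() reduces to // here
new_shape = q.shape[:-1] + (num_attention_heads, head_size)
q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()

scores = torch.matmul(q, k.transpose(-1, -2))     # (2, 8, 16, 16)
print(q.shape, scores.shape)                      # [2, 8, 16, 8], [2, 8, 16, 16]

The practical difference is the layout the fused weight is expected to have: chunking first treats the output dimension as three contiguous blocks [Q | K | V], whereas the old code, by viewing to (heads, 3 * head_size) before chunking, assumed q, k, and v were grouped per head.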