import torch


def forward_fn():
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape

        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

        # replace the dropout call with the added DropoutForParallelInput layer
        # original code:
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = self.dropout_layer(attn_weights)
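        # Note: self.dropout_layer is assumed to be a DropoutForParallelInput module
        # (from colossalai.shardformer.layer), which is meant to keep an independent
        # RNG state on each tensor-parallel rank so that ranks holding different
        # attention-head shards draw independent dropout masks.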

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs

    return forward
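
# Usage sketch (assumption, for illustration only): forward_fn() returns an unbound
# forward that a shardformer policy can graft onto transformers' SamVisionAttention,
# typically alongside a sub-module replacement that installs DropoutForParallelInput
# as `self.dropout_layer`. The policy call below follows the usual ColossalAI
# shardformer pattern and is not part of this file:
#
#   from transformers.models.sam.modeling_sam import SamVisionAttention
#
#   self.append_or_create_method_replacement(
#       description={"forward": forward_fn()},
#       policy=policy,
#       target_key=SamVisionAttention,
#   )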