import torch


def forward_fn():
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

        # scaled dot-product attention scores
        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        # add decomposed relative positional embeddings if they are enabled
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        # softmax in float32 for numerical stability, then cast back to the query dtype
        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

        # replace the dropout call with the added DropoutForParallelInput layer
        # original code:
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = self.dropout_layer(attn_weights)

        # merge heads back to (batch_size, height, width, channel)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs

    return forward
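

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file from
# https://github.com/hpcaitech/ColossalAI): forward_fn() returns an unbound
# forward that is intended to be swapped in for the forward of Hugging Face's
# SAM vision attention module when ShardFormer shards the model. The helper
# below shows one way to bind it manually; the helper name is hypothetical
# and is not the repository's actual wiring.
# ---------------------------------------------------------------------------
def patch_sam_vision_attention(attn_module) -> None:
    """Illustrative helper: bind the replacement forward onto an existing
    SamVisionAttention-like module.

    Assumes `attn_module.dropout_layer` (e.g. a DropoutForParallelInput
    instance) has already been attached, since the new forward calls it.
    """
    # Bind the plain function as a method of this specific module instance.
    attn_module.forward = forward_fn().__get__(attn_module, type(attn_module))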