mirror of https://github.com/hpcaitech/ColossalAI
import torch


def forward_fn():
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

        # replace the dropout process with the added DropoutForParallelInput layer
        # original code:
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = self.dropout_layer(attn_weights)

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs

    return forward
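The forward returned by forward_fn() is a drop-in replacement for a vision attention forward pass, so it has to be bound onto a module instance that exposes the attributes it references (qkv, num_attention_heads, scale, use_rel_pos, dropout_layer, proj). Below is a minimal sketch of that binding, assuming forward_fn from the file above is in scope; the DummyVisionAttention class and its sizes are hypothetical stand-ins invented for illustration, not part of ColossalAI, which applies the replacement to the Hugging Face SAM vision attention module through its shardformer policy machinery instead.

import types

import torch
import torch.nn as nn


class DummyVisionAttention(nn.Module):
    # Hypothetical stand-in exposing the attribute names the replacement forward expects.
    def __init__(self, hidden_size=64, num_attention_heads=4):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.scale = (hidden_size // num_attention_heads) ** -0.5
        self.use_rel_pos = False  # skip add_decomposed_rel_pos in this sketch
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.proj = nn.Linear(hidden_size, hidden_size)
        # stands in for the DropoutForParallelInput layer mentioned in the comment above
        self.dropout_layer = nn.Dropout(p=0.0)


attn = DummyVisionAttention()
# bind the replacement forward to this instance
attn.forward = types.MethodType(forward_fn(), attn)

hidden_states = torch.randn(2, 8, 8, 64)  # (batch_size, height, width, hidden_size)
attn_output, attn_weights = attn(hidden_states, output_attentions=True)
print(attn_output.shape)   # torch.Size([2, 8, 8, 64])
print(attn_weights.shape)  # torch.Size([8, 64, 64]), i.e. (batch_size * nHead, H * W, H * W)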