from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def forward_fn():
    """Return a replacement ``forward`` for a windowed-attention module.

    The returned function takes ``self`` as its first argument, so it must be
    bound to (or called with) an object exposing the attributes it reads:
    ``qkv``, ``proj``, ``dropout_layer``, ``num_attention_heads``, ``scale``,
    ``use_rel_pos`` and, when ``use_rel_pos`` is truthy,
    ``add_decomposed_rel_pos``, ``rel_pos_h`` and ``rel_pos_w``.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute multi-head self-attention over a 2D grid of features.

        Args:
            hidden_states: 4-D tensor unpacked as
                ``(batch_size, height, width, channels)``.
            output_attentions: When true, the post-softmax (pre-dropout)
                attention weights are returned as the second tuple element.

        Returns:
            A 2-tuple ``(attn_output, attn_weights_or_None)``.
            ``attn_output`` is the result of ``self.proj`` applied to a
            ``(batch_size, height, width, -1)`` tensor. The second element is
            the ``(batch_size * nHead, height * width, height * width)``
            attention weights, or ``None`` when ``output_attentions`` is false.
        """
        batch_size, height, width, _ = hidden_states.shape
        num_heads = self.num_attention_heads
        num_tokens = height * width

        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, num_tokens, 3, num_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(
            3, batch_size * num_heads, num_tokens, -1
        ).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights,
                query,
                self.rel_pos_h,
                self.rel_pos_w,
                (height, width),
                (height, width),
            )

        # Softmax is evaluated in float32 and cast back to the query dtype.
        attn_weights = torch.nn.functional.softmax(
            attn_weights, dtype=torch.float32, dim=-1
        ).to(query.dtype)

        # The dropout step is replaced with the added DropoutForParallelInput
        # layer (``self.dropout_layer``). The original code called
        # ``nn.functional.dropout(attn_weights, p=self.dropout, ...)`` here.
        attn_probs = self.dropout_layer(attn_weights)

        # Split the fused (batch * head) axis, move heads next to the channel
        # axis, and merge them back into a single feature dimension.
        attn_output = (attn_probs @ value).reshape(
            batch_size, num_heads, height, width, -1
        )
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(
            batch_size, height, width, -1
        )

        attn_output = self.proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None)

    return forward