ColossalAI/colossalai/shardformer/modeling/sam.py

204 lines
8.2 KiB
Python

import math
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def forward_fn():
    """Build a replacement ``forward`` for a SAM vision attention module.

    The returned function performs the usual multi-head self-attention over a
    2-D grid of tokens, but applies attention-probability dropout through
    ``self.dropout_layer`` instead of calling ``nn.functional.dropout``
    directly. This lets a parallel-aware dropout layer (e.g.
    ``DropoutForParallelInput``) be injected.

    Returns:
        A function ``forward(self, hidden_states, output_attentions=False)``
        meant to be bound as a method.
    """

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        """Run multi-head self-attention on ``hidden_states``.

        Args:
            hidden_states: tensor unpacked as (batch_size, height, width, channels).
            output_attentions: if True, also return the softmaxed attention
                weights (taken before dropout); otherwise return None in their place.

        Returns:
            Tuple ``(attn_output, attn_weights_or_None)``. ``attn_output`` has
            shape (batch_size, height, width, -1) after ``self.proj``.
            ``attn_weights`` has shape (batch_size * nHead, height * width,
            height * width).
        """
        batch_size, height, width, _ = hidden_states.shape
        num_heads = self.num_attention_heads
        num_tokens = height * width

        # Project once, then lay out as (3, batch_size, nHead, tokens, head_dim).
        projected = self.qkv(hidden_states)
        qkv = projected.reshape(batch_size, num_tokens, 3, num_heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # Fold the heads into the batch dim: each of q/k/v is (batch_size * nHead, tokens, head_dim).
        query, key, value = qkv.reshape(3, batch_size * num_heads, num_tokens, -1).unbind(0)

        scores = (query * self.scale) @ key.transpose(-2, -1)
        if self.use_rel_pos:
            scores = self.add_decomposed_rel_pos(
                scores, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
        # Softmax is computed in float32, then cast back to the query dtype.
        scores = torch.nn.functional.softmax(scores, dtype=torch.float32, dim=-1).to(query.dtype)

        # Dropout is delegated to the injected layer; it stands in for
        # nn.functional.dropout(attn_weights, p=self.dropout, training=self.training).
        probs = self.dropout_layer(scores)

        # Un-fold the heads and move them next to the channel dim before projecting.
        context = (probs @ value).reshape(batch_size, num_heads, height, width, -1)
        context = context.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        context = self.proj(context)

        returned_weights = scores if output_attentions else None
        return context, returned_weights

    return forward
def get_sam_flash_attention_forward():
    """Build an xformers-backed ``forward`` for ``SamAttention``.

    The returned function projects query/key/value, splits them into heads and
    runs ``xformers.ops.memory_efficient_attention``. It then recombines the
    heads and applies ``self.out_proj``.

    Returns:
        A function ``forward(self, query, key, value, attention_similarity=None)``
        meant to be bound to a ``SamAttention`` instance.

    Raises:
        ImportError: if ``xformers`` cannot be imported (the original import
            error is chained as the cause).
    """
    from transformers.models.sam.modeling_sam import SamAttention

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError as err:
        # Narrow catch: a bare ``except`` would also swallow KeyboardInterrupt/SystemExit;
        # chaining keeps the real import failure visible.
        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from err

    def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
        """Reshape (batch, point_batch_size, n_tokens, channel) into
        (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head).

        No head/token transpose is done: this (batch, tokens, heads, dim) layout
        is the one handed directly to ``memory_efficient_attention``.
        """
        batch, point_batch_size, n_tokens, channel = hidden_states.shape
        c_per_head = channel // num_attention_heads
        return hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)

    def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
        """Inverse of ``_separate_heads``: (batch * point_batch_size, n_tokens, n_heads, c_per_head)
        -> (batch, point_batch_size, n_tokens, n_heads * c_per_head)."""
        batch, n_tokens, n_heads, c_per_head = hidden_states.shape
        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

    def forward(self: SamAttention,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                attention_similarity: Tensor = None) -> Tensor:
        """Attention over (batch, point_batch_size, n_tokens, channel) inputs.

        Args:
            query, key, value: inputs with point_batch_size on dim 1 (read from ``query.shape[1]``).
            attention_similarity: optional additive bias forwarded to the kernel as
                ``attn_bias``; None means no bias.

        Returns:
            Output of ``self.out_proj`` on the recombined heads.
        """
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        point_batch_size = query.shape[1]

        # Separate into heads
        query = _separate_heads(query, self.num_attention_heads)
        key = _separate_heads(key, self.num_attention_heads)
        value = _separate_heads(value, self.num_attention_heads)

        # Standard 1/sqrt(head_dim) scaling.
        c_per_head = query.shape[-1]
        scale = 1.0 / math.sqrt(c_per_head)
        out = me_attention(query, key, value, attn_bias=attention_similarity, scale=scale)

        out = _recombine_heads(out, point_batch_size)
        return self.out_proj(out)

    return forward
def get_sam_vision_flash_attention_forward():
    """Build an xformers-backed ``forward`` for ``SamVisionAttention``.

    The returned function computes multi-head self-attention with
    ``xformers.ops.memory_efficient_attention``. When ``self.use_rel_pos`` is
    set, the decomposed relative-position term is passed to the kernel as an
    additive ``attn_bias``.

    Returns:
        A function ``forward(self, hidden_states, output_attentions=False)``
        meant to be bound to a ``SamVisionAttention`` instance.

    Raises:
        ImportError: if ``xformers`` cannot be imported (the original import
            error is chained as the cause).
    """
    from transformers.models.sam.modeling_sam import SamVisionAttention

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError as err:
        # Narrow catch: a bare ``except`` would also swallow KeyboardInterrupt/SystemExit;
        # chaining keeps the real import failure visible.
        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from err

    def add_decomposed_rel_pos(
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate the decomposed relative positional bias from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        This returns only the bias term; it does not add it to an existing
        attention map. The result can be passed to ``memory_efficient_attention``
        as ``attn_bias``.

        Args:
            query (`torch.Tensor`):
                query with shape (batch_size, query_height * query_width, nHead, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            rel_pos (`torch.Tensor`):
                bias with shape (batch_size, nHead, query_height * query_width, key_height * key_width).
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, nHead, dim = query.shape
        # (B, HW, nHead, C) -> (B * nHead, query_height, query_width, C)
        reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        # Broadcast to (B * nHead, qh, qw, kh, kw), then flatten to (B, nHead, qh * qw, kh * kw).
        rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
        return rel_pos

    def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings of shape (q_size, k_size, channel).
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos to length max_rel_dist along the position axis.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        # Build the index tensors on rel_pos's device so indexing needs no host->device copy.
        device = rel_pos.device
        q_coords = torch.arange(q_size, device=device)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size, device=device)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def forward(self: SamVisionAttention,
                hidden_states: torch.Tensor,
                output_attentions=False) -> Tuple[torch.Tensor, None]:
        """Self-attention over a (batch_size, height, width, channels) token grid.

        Args:
            hidden_states: input unpacked as (batch_size, height, width, channels).
            output_attentions: accepted for interface compatibility only. This path
                never computes explicit attention weights, so the second element of
                the returned tuple is always None.

        Returns:
            Tuple ``(attn_output, None)`` where ``attn_output`` is
            (batch_size, height, width, -1) after ``self.proj``.
        """
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, height * width, nHead, channel); each of q/k/v is then
        # (batch_size, height * width, nHead, channel), the layout handed to memory_efficient_attention.
        qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
                                               -1).permute(2, 0, 1, 3, 4))
        query, key, value = qkv.unbind(0)

        rel_pos = None
        if self.use_rel_pos:
            rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))

        # Apply attention dropout only while training, as the eager path does with
        # nn.functional.dropout(..., p=self.dropout, training=self.training).
        dropout_p = self.dropout if self.training else 0.0
        attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=dropout_p, scale=self.scale)

        # (batch_size, height * width, nHead, channel) -> (batch_size, height, width, nHead * channel)
        attn_output = attn_output.reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        return attn_output, None

    return forward