mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
8.1 KiB
208 lines
8.1 KiB
import math
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
|
|
def forward_fn():
|
|
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
|
batch_size, height, width, _ = hidden_states.shape
|
|
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
|
qkv = (
|
|
self.qkv(hidden_states)
|
|
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
|
.permute(2, 0, 3, 1, 4)
|
|
)
|
|
# q, k, v with shape (batch_size * nHead, height * width, channel)
|
|
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
|
|
|
|
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
|
|
|
if self.use_rel_pos:
|
|
attn_weights = self.add_decomposed_rel_pos(
|
|
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
|
)
|
|
|
|
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
|
|
|
# replace dropout process with added DropoutForParallelInput layer
|
|
# origin code:
|
|
# attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
|
attn_probs = self.dropout_layer(attn_weights)
|
|
|
|
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
|
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
|
|
|
|
attn_output = self.proj(attn_output)
|
|
|
|
if output_attentions:
|
|
outputs = (attn_output, attn_weights)
|
|
else:
|
|
outputs = (attn_output, None)
|
|
|
|
return outputs
|
|
|
|
return forward
|
|
|
|
|
|
def get_sam_flash_attention_forward():
|
|
from transformers.models.sam.modeling_sam import SamAttention
|
|
|
|
try:
|
|
from xformers.ops import memory_efficient_attention as me_attention
|
|
except:
|
|
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
|
|
|
def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
|
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
|
c_per_head = channel // num_attention_heads
|
|
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
|
|
return hidden_states
|
|
|
|
def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
|
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
|
|
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
|
|
|
def forward(
|
|
self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
|
|
) -> Tensor:
|
|
# Input projections
|
|
query = self.q_proj(query)
|
|
key = self.k_proj(key)
|
|
value = self.v_proj(value)
|
|
|
|
point_batch_size = query.shape[1]
|
|
# Separate into heads
|
|
query = _separate_heads(query, self.num_attention_heads)
|
|
key = _separate_heads(key, self.num_attention_heads)
|
|
value = _separate_heads(value, self.num_attention_heads)
|
|
|
|
# SamAttention
|
|
_, _, _, c_per_head = query.shape
|
|
bias = None
|
|
if attention_similarity is not None:
|
|
bias = attention_similarity
|
|
|
|
scale = 1.0 / math.sqrt(c_per_head)
|
|
out = me_attention(query, key, value, attn_bias=bias, scale=scale)
|
|
|
|
out = _recombine_heads(out, point_batch_size)
|
|
out = self.out_proj(out)
|
|
|
|
return out
|
|
|
|
return forward
|
|
|
|
|
|
def get_sam_vision_flash_attention_forward():
|
|
from transformers.models.sam.modeling_sam import SamVisionAttention
|
|
|
|
try:
|
|
from xformers.ops import memory_efficient_attention as me_attention
|
|
except:
|
|
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
|
|
|
def add_decomposed_rel_pos(
|
|
query: torch.Tensor,
|
|
rel_pos_h: torch.Tensor,
|
|
rel_pos_w: torch.Tensor,
|
|
q_size: Tuple[int, int],
|
|
k_size: Tuple[int, int],
|
|
) -> torch.Tensor:
|
|
"""
|
|
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
|
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
|
|
|
Args:
|
|
attn (`torch.Tensor`):
|
|
attention map.
|
|
query (`torch.Tensor`):
|
|
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
|
|
rel_pos_h (`torch.Tensor`):
|
|
relative position embeddings (Lh, channel) for height axis.
|
|
rel_pos_w (`torch.Tensor`):
|
|
relative position embeddings (Lw, channel) for width axis.
|
|
q_size (tuple):
|
|
spatial sequence size of query q with (query_height, query_width).
|
|
k_size (tuple):
|
|
spatial sequence size of key k with (key_height, key_width).
|
|
|
|
Returns:
|
|
attn (`torch.Tensor`):
|
|
attention map with added relative positional embeddings.
|
|
"""
|
|
|
|
query_height, query_width = q_size
|
|
key_height, key_width = k_size
|
|
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
|
|
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
|
|
|
|
batch_size, _, nHead, dim = query.shape
|
|
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
|
|
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
|
|
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
|
|
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
|
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
|
|
return rel_pos
|
|
|
|
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Get relative positional embeddings according to the relative positions of
|
|
query and key sizes.
|
|
|
|
Args:
|
|
q_size (int):
|
|
size of the query.
|
|
k_size (int):
|
|
size of key k.
|
|
rel_pos (`torch.Tensor`):
|
|
relative position embeddings (L, channel).
|
|
|
|
Returns:
|
|
Extracted positional embeddings according to relative positions.
|
|
"""
|
|
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
|
# Interpolate rel pos.
|
|
rel_pos_resized = F.interpolate(
|
|
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
|
size=max_rel_dist,
|
|
mode="linear",
|
|
)
|
|
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
|
|
|
# Scale the coords with short length if shapes for q and k are different.
|
|
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
|
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
|
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
|
|
|
return rel_pos_resized[relative_coords.long()]
|
|
|
|
def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
|
batch_size, height, width, _ = hidden_states.shape
|
|
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
|
qkv = (
|
|
self.qkv(hidden_states)
|
|
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
|
.permute(2, 0, 1, 3, 4)
|
|
)
|
|
|
|
query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
|
|
|
|
rel_pos = None
|
|
if self.use_rel_pos:
|
|
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
|
|
|
|
attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
|
|
|
|
attn_output = attn_output.reshape(batch_size, height, width, -1)
|
|
|
|
attn_output = self.proj(attn_output)
|
|
|
|
outputs = (attn_output, None)
|
|
|
|
return outputs
|
|
|
|
return forward
|