# mirror of https://github.com/hpcaitech/ColossalAI
import math
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def forward_fn():
    # Returns a replacement attention forward whose dropout runs through the
    # module's pre-registered ``dropout_layer`` (a DropoutForParallelInput)
    # instead of ``nn.functional.dropout``.
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

        # replace dropout process with added DropoutForParallelInput layer
        # original code:
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = self.dropout_layer(attn_weights)

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs

    return forward


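# --- Illustrative usage (not part of the upstream file) ----------------------
# ``forward_fn`` is a factory: it returns a plain function that a caller is
# expected to bind onto an attention module as its ``forward``. The helper
# below only sketches that binding; its name is hypothetical, and a plain
# ``nn.Dropout`` stands in for the DropoutForParallelInput layer referenced in
# the comment above. How ColossalAI actually performs this replacement (e.g.
# via a shardformer policy) is not shown here.
def _example_bind_forward(attn_module, dropout_p: float) -> None:
    import types

    # The replacement forward expects ``attn_module.dropout_layer`` to exist.
    attn_module.dropout_layer = torch.nn.Dropout(dropout_p)  # stand-in layer
    attn_module.forward = types.MethodType(forward_fn(), attn_module)

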
def get_sam_flash_attention_forward():
    from transformers.models.sam.modeling_sam import SamAttention

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError:
        raise ImportError("xformers is not installed. Please install it to use flash attention.")

    def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
        batch, point_batch_size, n_tokens, channel = hidden_states.shape
        c_per_head = channel // num_attention_heads
        hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
        return hidden_states

    def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
        batch, n_tokens, n_heads, c_per_head = hidden_states.shape
        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

    def forward(
        self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
    ) -> Tensor:
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        point_batch_size = query.shape[1]
        # Separate into heads
        query = _separate_heads(query, self.num_attention_heads)
        key = _separate_heads(key, self.num_attention_heads)
        value = _separate_heads(value, self.num_attention_heads)

        # SamAttention
        _, _, _, c_per_head = query.shape
        bias = None
        if attention_similarity is not None:
            bias = attention_similarity

        scale = 1.0 / math.sqrt(c_per_head)
        out = me_attention(query, key, value, attn_bias=bias, scale=scale)

        out = _recombine_heads(out, point_batch_size)
        out = self.out_proj(out)

        return out

    return forward


def get_sam_vision_flash_attention_forward():
    from transformers.models.sam.modeling_sam import SamVisionAttention

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError:
        raise ImportError("xformers is not installed. Please install it to use flash attention.")

    def add_decomposed_rel_pos(
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, nHead, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            rel_pos (`torch.Tensor`):
                decomposed relative positional bias to be added to the attention map.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, nHead, dim = query.shape
        reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
        return rel_pos

    def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, height * width, nHead, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 1, 3, 4)
        )

        query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)

        rel_pos = None
        if self.use_rel_pos:
            rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))

        attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)

        attn_output = attn_output.reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        outputs = (attn_output, None)

        return outputs

    return forward
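

# --- Illustrative usage (not part of the upstream file) ----------------------
# The two flash-attention factories above likewise just return functions to be
# bound onto the matching transformers SAM attention modules. The helper below
# sketches the generic monkey-patching pattern only; its name is hypothetical,
# it assumes xformers is installed (otherwise the factories raise ImportError),
# and the actual replacement inside ColossalAI may be done differently.
def _example_enable_flash_attention(sam_model) -> None:
    import types

    from transformers.models.sam.modeling_sam import SamAttention, SamVisionAttention

    vision_forward = get_sam_vision_flash_attention_forward()
    attn_forward = get_sam_flash_attention_forward()
    for module in sam_model.modules():
        if isinstance(module, SamVisionAttention):
            module.forward = types.MethodType(vision_forward, module)
        elif isinstance(module, SamAttention):
            module.forward = types.MethodType(attn_forward, module)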