[shardformer]delete xformers (#5859)

* delete xformers

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5874/head
flybird11111 5 months ago committed by GitHub
parent eaea88cf9e
commit 773d9f964a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,3 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
@ -1005,115 +1004,6 @@ class BertPipelineForwards:
return {"hidden_states": hidden_states}
def get_bert_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bert.modeling_bert import BertAttention
def forward(
self: BertAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
final_attention_mask = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
final_attention_mask = relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
final_attention_mask = relative_position_scores_query + relative_position_scores_key
scale = 1 / math.sqrt(self.attention_head_size)
if attention_mask is not None:
if final_attention_mask != None:
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
if final_attention_mask is not None:
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
tgt_len = key_layer.size()[2]
final_attention_mask = final_attention_mask.expand(
batch_size, self.num_attention_heads, src_len, tgt_len
).contiguous()
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
context_layer = me_attention(
query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return forward
def get_jit_fused_bert_self_output_forward():
from transformers.models.bert.modeling_bert import BertSelfOutput

@ -714,93 +714,6 @@ class BloomPipelineForwards:
return {"hidden_states": hidden_states}
def get_bloom_flash_attention_forward(enable_jit_fused=False):
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _, _ = query_layer.size()
_, kv_length, _, _ = key_layer.size()
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
query_layer = query_layer.contiguous().view(*proj_shape)
key_layer = key_layer.contiguous().view(*proj_shape)
value_layer = value_layer.contiguous().view(*proj_shape)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
tgt_len = key_layer.size()[1]
attention_numerical_mask = torch.zeros(
(batch_size, self.num_heads, tgt_len, kv_length),
dtype=torch.float32,
device=query_layer.device,
requires_grad=True,
)
attention_numerical_mask = (
attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
attention_numerical_mask = torch.masked_fill(
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
)
attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
context_layer = me_attention(
query_layer,
key_layer,
value_layer,
attn_bias=attention_numerical_mask,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# TODO to replace with the bias_dropout_add function in jit
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present, None)
return outputs
return forward
def get_jit_fused_bloom_attention_forward():
from transformers.models.bloom.modeling_bloom import BloomAttention

@ -1,9 +1,4 @@
import math
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def forward_fn():
@ -45,163 +40,3 @@ def forward_fn():
return outputs
return forward
def get_sam_flash_attention_forward():
from transformers.models.sam.modeling_sam import SamAttention
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
return hidden_states
def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(
self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
# Separate into heads
query = _separate_heads(query, self.num_attention_heads)
key = _separate_heads(key, self.num_attention_heads)
value = _separate_heads(value, self.num_attention_heads)
# SamAttention
_, _, _, c_per_head = query.shape
bias = None
if attention_similarity is not None:
bias = attention_similarity
scale = 1.0 / math.sqrt(c_per_head)
out = me_attention(query, key, value, attn_bias=bias, scale=scale)
out = _recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
return forward
def get_sam_vision_flash_attention_forward():
from transformers.models.sam.modeling_sam import SamVisionAttention
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
def add_decomposed_rel_pos(
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`torch.Tensor`):
attention map.
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`torch.Tensor`):
relative position embeddings (Lh, channel) for height axis.
rel_pos_w (`torch.Tensor`):
relative position embeddings (Lw, channel) for width axis.
q_size (tuple):
spatial sequence size of query q with (query_height, query_width).
k_size (tuple):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`torch.Tensor`):
attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, nHead, dim = query.shape
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
return rel_pos
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int):
size of the query.
k_size (int):
size of key k.
rel_pos (`torch.Tensor`):
relative position embeddings (L, channel).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, batch_size, nHead, height * width, channel)
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
.permute(2, 0, 1, 3, 4)
)
query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
rel_pos = None
if self.use_rel_pos:
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
attn_output = attn_output.reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
outputs = (attn_output, None)
return outputs
return forward

@ -11,7 +11,6 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
@ -49,7 +48,6 @@ class BertPolicy(Policy):
BertLayer,
BertModel,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
@ -218,16 +216,6 @@ class BertPolicy(Policy):
target_key=BertEmbeddings,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_bert_flash_attention_forward(),
},
policy=policy,
target_key=BertSelfAttention,
)
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(

@ -11,14 +11,13 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bloom import (
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
get_bloom_flash_attention_forward,
get_bloom_sequence_parallel_forward_fn,
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
)
from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -165,16 +164,6 @@ class BloomPolicy(Policy):
target_key=BloomModel,
)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_bloom_flash_attention_forward(),
"dropout_add": get_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention,
)
# enable jit fused operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(

@ -1,5 +1,3 @@
import warnings
import colossalai.shardformer.layer as col_nn
from ..modeling.sam import forward_fn
@ -212,24 +210,6 @@ class SamPolicy(Policy):
target_key=SamTwoWayTransformer,
)
# use flash attention
if self.shard_config.enable_flash_attention:
warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
# self.append_or_create_method_replacement(
# description={
# "forward": get_sam_flash_attention_forward(),
# },
# policy=policy,
# target_key=SamAttention,
# )
# self.append_or_create_method_replacement(
# description={
# "forward": get_sam_vision_flash_attention_forward(),
# },
# policy=policy,
# target_key=SamVisionAttention,
# )
return policy
def postprocess(self):

@ -71,8 +71,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
@ -95,8 +95,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
@ -155,8 +155,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>

Loading…
Cancel
Save