mirror of https://github.com/hpcaitech/ColossalAI
[shardformer]delete xformers (#5859)
* delete xformers * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5874/head
parent
eaea88cf9e
commit
773d9f964a
|
@ -1,4 +1,3 @@
|
|||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
@ -1005,115 +1004,6 @@ class BertPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_bert_flash_attention_forward():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.bert.modeling_bert import BertAttention
|
||||
|
||||
def forward(
|
||||
self: BertAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
final_attention_mask = None
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
final_attention_mask = relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
final_attention_mask = relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
scale = 1 / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
if final_attention_mask != None:
|
||||
final_attention_mask = final_attention_mask * scale + attention_mask
|
||||
else:
|
||||
final_attention_mask = attention_mask
|
||||
|
||||
if final_attention_mask is not None:
|
||||
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
|
||||
tgt_len = key_layer.size()[2]
|
||||
final_attention_mask = final_attention_mask.expand(
|
||||
batch_size, self.num_attention_heads, src_len, tgt_len
|
||||
).contiguous()
|
||||
|
||||
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
|
||||
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
|
||||
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
context_layer = me_attention(
|
||||
query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale
|
||||
)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, None)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bert_self_output_forward():
|
||||
from transformers.models.bert.modeling_bert import BertSelfOutput
|
||||
|
||||
|
|
|
@ -714,93 +714,6 @@ class BloomPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_bloom_flash_attention_forward(enable_jit_fused=False):
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
def forward(
|
||||
self: BloomAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states)
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
batch_size, tgt_len, _, _ = query_layer.size()
|
||||
|
||||
_, kv_length, _, _ = key_layer.size()
|
||||
|
||||
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
|
||||
query_layer = query_layer.contiguous().view(*proj_shape)
|
||||
key_layer = key_layer.contiguous().view(*proj_shape)
|
||||
value_layer = value_layer.contiguous().view(*proj_shape)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
tgt_len = key_layer.size()[1]
|
||||
|
||||
attention_numerical_mask = torch.zeros(
|
||||
(batch_size, self.num_heads, tgt_len, kv_length),
|
||||
dtype=torch.float32,
|
||||
device=query_layer.device,
|
||||
requires_grad=True,
|
||||
)
|
||||
attention_numerical_mask = (
|
||||
attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
|
||||
)
|
||||
attention_numerical_mask = torch.masked_fill(
|
||||
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
|
||||
)
|
||||
attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
|
||||
|
||||
context_layer = me_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_bias=attention_numerical_mask,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p,
|
||||
)
|
||||
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
slices = self.hidden_size / self.pretraining_tp
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
# TODO to replace with the bias_dropout_add function in jit
|
||||
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||
outputs = (output_tensor, present, None)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bloom_attention_forward():
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
|
|
|
@ -1,9 +1,4 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def forward_fn():
|
||||
|
@ -45,163 +40,3 @@ def forward_fn():
|
|||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_sam_flash_attention_forward():
|
||||
from transformers.models.sam.modeling_sam import SamAttention
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
|
||||
def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
||||
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
||||
c_per_head = channel // num_attention_heads
|
||||
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
|
||||
return hidden_states
|
||||
|
||||
def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
||||
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
|
||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||
|
||||
def forward(
|
||||
self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
|
||||
) -> Tensor:
|
||||
# Input projections
|
||||
query = self.q_proj(query)
|
||||
key = self.k_proj(key)
|
||||
value = self.v_proj(value)
|
||||
|
||||
point_batch_size = query.shape[1]
|
||||
# Separate into heads
|
||||
query = _separate_heads(query, self.num_attention_heads)
|
||||
key = _separate_heads(key, self.num_attention_heads)
|
||||
value = _separate_heads(value, self.num_attention_heads)
|
||||
|
||||
# SamAttention
|
||||
_, _, _, c_per_head = query.shape
|
||||
bias = None
|
||||
if attention_similarity is not None:
|
||||
bias = attention_similarity
|
||||
|
||||
scale = 1.0 / math.sqrt(c_per_head)
|
||||
out = me_attention(query, key, value, attn_bias=bias, scale=scale)
|
||||
|
||||
out = _recombine_heads(out, point_batch_size)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_sam_vision_flash_attention_forward():
|
||||
from transformers.models.sam.modeling_sam import SamVisionAttention
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
query: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
||||
|
||||
Args:
|
||||
attn (`torch.Tensor`):
|
||||
attention map.
|
||||
query (`torch.Tensor`):
|
||||
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
|
||||
rel_pos_h (`torch.Tensor`):
|
||||
relative position embeddings (Lh, channel) for height axis.
|
||||
rel_pos_w (`torch.Tensor`):
|
||||
relative position embeddings (Lw, channel) for width axis.
|
||||
q_size (tuple):
|
||||
spatial sequence size of query q with (query_height, query_width).
|
||||
k_size (tuple):
|
||||
spatial sequence size of key k with (key_height, key_width).
|
||||
|
||||
Returns:
|
||||
attn (`torch.Tensor`):
|
||||
attention map with added relative positional embeddings.
|
||||
"""
|
||||
|
||||
query_height, query_width = q_size
|
||||
key_height, key_width = k_size
|
||||
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
|
||||
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
|
||||
|
||||
batch_size, _, nHead, dim = query.shape
|
||||
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
|
||||
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
||||
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
|
||||
return rel_pos
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int):
|
||||
size of the query.
|
||||
k_size (int):
|
||||
size of key k.
|
||||
rel_pos (`torch.Tensor`):
|
||||
relative position embeddings (L, channel).
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
||||
batch_size, height, width, _ = hidden_states.shape
|
||||
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
||||
qkv = (
|
||||
self.qkv(hidden_states)
|
||||
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
||||
.permute(2, 0, 1, 3, 4)
|
||||
)
|
||||
|
||||
query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
|
||||
|
||||
rel_pos = None
|
||||
if self.use_rel_pos:
|
||||
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
|
||||
|
||||
attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, height, width, -1)
|
||||
|
||||
attn_output = self.proj(attn_output)
|
||||
|
||||
outputs = (attn_output, None)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
|
|
@ -11,7 +11,6 @@ import colossalai.shardformer.layer as col_nn
|
|||
from ..modeling.bert import (
|
||||
BertPipelineForwards,
|
||||
bert_sequence_parallel_forward_fn,
|
||||
get_bert_flash_attention_forward,
|
||||
get_jit_fused_bert_intermediate_forward,
|
||||
get_jit_fused_bert_output_forward,
|
||||
get_jit_fused_bert_self_output_forward,
|
||||
|
@ -49,7 +48,6 @@ class BertPolicy(Policy):
|
|||
BertLayer,
|
||||
BertModel,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
|
||||
|
@ -218,16 +216,6 @@ class BertPolicy(Policy):
|
|||
target_key=BertEmbeddings,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_bert_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BertSelfAttention,
|
||||
)
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
|
|
|
@ -11,14 +11,13 @@ import colossalai.shardformer.layer as col_nn
|
|||
from ..modeling.bloom import (
|
||||
BloomPipelineForwards,
|
||||
build_bloom_alibi_tensor_fn,
|
||||
get_bloom_flash_attention_forward,
|
||||
get_bloom_sequence_parallel_forward_fn,
|
||||
get_jit_fused_bloom_attention_forward,
|
||||
get_jit_fused_bloom_gelu_forward,
|
||||
get_jit_fused_bloom_mlp_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
)
|
||||
from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
|
@ -165,16 +164,6 @@ class BloomPolicy(Policy):
|
|||
target_key=BloomModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_bloom_flash_attention_forward(),
|
||||
"dropout_add": get_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BloomAttention,
|
||||
)
|
||||
|
||||
# enable jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import warnings
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.sam import forward_fn
|
||||
|
@ -212,24 +210,6 @@ class SamPolicy(Policy):
|
|||
target_key=SamTwoWayTransformer,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
|
||||
# self.append_or_create_method_replacement(
|
||||
# description={
|
||||
# "forward": get_sam_flash_attention_forward(),
|
||||
# },
|
||||
# policy=policy,
|
||||
# target_key=SamAttention,
|
||||
# )
|
||||
# self.append_or_create_method_replacement(
|
||||
# description={
|
||||
# "forward": get_sam_vision_flash_attention_forward(),
|
||||
# },
|
||||
# policy=policy,
|
||||
# target_key=SamVisionAttention,
|
||||
# )
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@ -71,8 +71,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
|
|||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
|
@ -95,8 +95,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
|
|||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
|
@ -155,8 +155,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
|
|||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
|
|
Loading…
Reference in New Issue