2022-11-11 09:08:17 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2023-09-11 08:24:28 +00:00
|
|
|
|
2022-11-11 09:08:17 +00:00
|
|
|
from colossalai.kernel.cuda_native import LayerNorm
|
2023-09-11 08:24:28 +00:00
|
|
|
from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
|
|
|
|
from colossalai.legacy.nn.layer.parallel_sequence import TransformerSelfAttentionRing
|
|
|
|
|
2022-11-11 09:08:17 +00:00
|
|
|
from .dropout import get_bias_dropout_add
|
2023-09-11 08:24:28 +00:00
|
|
|
from .mlp import TransformerMLP
|
2022-11-11 09:08:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
def attention_mask_func(attention_scores, attention_mask):
    """Mask out attention scores in place.

    Every position where ``attention_mask`` is True is overwritten with
    -10000.0 directly in ``attention_scores``.

    Args:
        attention_scores: tensor of raw attention scores; modified in place.
        attention_mask: boolean mask broadcastable to ``attention_scores``;
            True marks positions to suppress.

    Returns:
        The same ``attention_scores`` tensor object, after masking.
    """
    # ``masked_fill_`` mutates the tensor and returns it, so the caller gets
    # back the identical object it passed in.
    return attention_scores.masked_fill_(attention_mask, -10000.0)
|
|
|
|
|
|
|
|
|
|
|
|
class BertLayer(nn.Module):
    """A single transformer layer.

    The layer takes input of size ``[b, s, h]`` and returns an output of the
    same size. Data flows through: input layernorm -> ring self-attention ->
    bias+dropout+residual add -> post-attention layernorm -> MLP ->
    bias+dropout+residual add.

    Args:
        layer_number: index of this layer; stored and forwarded to the
            attention module.
        hidden_size: hidden dimension used by both layernorms, the attention
            module and the MLP.
        num_attention_heads: forwarded to the attention module.
        attention_dropout: forwarded to the attention module.
        mlp_ratio: forwarded to ``TransformerMLP``.
        hidden_dropout: dropout value passed to the bias+dropout+add function
            before each residual connection.
        is_naive_fp16: forwarded as ``fp16`` to the attention module.
        apply_residual_connection_post_layernorm: when True, the residual
            branches are taken from the layernorm outputs; otherwise from the
            layernorm inputs.
        fp32_residual_connection: stored on the module; not read inside this
            class.
        bias_dropout_fusion: when True, use the JIT-fused train/inference
            bias+dropout+add functions.
        convert_fp16_to_fp32_in_softmax: forwarded to the attention module.
    """

    def __init__(self,
                 layer_number,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout,
                 mlp_ratio,
                 hidden_dropout,
                 is_naive_fp16,
                 apply_residual_connection_post_layernorm=False,
                 fp32_residual_connection=False,
                 bias_dropout_fusion: bool = True,
                 convert_fp16_to_fp32_in_softmax: bool = False):
        super().__init__()
        self.layer_number = layer_number
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.fp32_residual_connection = fp32_residual_connection

        # NOTE: submodules are registered in a fixed order (input layernorm,
        # attention, post-attention layernorm, MLP). Keep it: parameter
        # iteration order and state_dict key order depend on it.
        self.input_layernorm = LayerNorm(hidden_size)

        attention_kwargs = dict(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout=attention_dropout,
            attention_mask_func=attention_mask_func,
            layer_number=layer_number,
            apply_query_key_layer_scaling=True,
            convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
            fp16=is_naive_fp16,
        )
        self.self_attention = TransformerSelfAttentionRing(**attention_kwargs)

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion

        self.post_attention_layernorm = LayerNorm(hidden_size)
        self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def _pick_bias_dropout_add(self):
        """Return the bias+dropout+residual-add callable for the current mode.

        JIT scripting an nn.Module that contains dropout does not trigger the
        fusion kernel. Instead, two separate functional routines are used so
        that dropout behaves correctly in both training and inference.
        """
        if not self.bias_dropout_fusion:
            return get_bias_dropout_add(self.training)
        return bias_dropout_add_fused_train if self.training else bias_dropout_add_fused_inference

    def _residual_source(self, normed, unnormed):
        """Pick the residual branch: the layernorm output or its input."""
        return normed if self.apply_residual_connection_post_layernorm else unnormed

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [batch_size, sub_seq_len, hidden_size]
        # attention_mask: [batch_size, 1, sub_seq_len, seq_len]

        # --- Attention sub-block ---
        attn_in = self.input_layernorm(hidden_states)
        attn_out, attn_bias = self.self_attention(attn_in, attention_mask)

        skip = self._residual_source(attn_in, hidden_states)
        add_bias_dropout_residual = self._pick_bias_dropout_add()

        # Grad mode is forced on around the fused add to enable the fused
        # optimization path.
        with torch.enable_grad():
            mlp_block_input = add_bias_dropout_residual(attn_out, attn_bias.expand_as(skip), skip,
                                                        self.hidden_dropout)

        # --- MLP sub-block ---
        mlp_in = self.post_attention_layernorm(mlp_block_input)
        mlp_out, mlp_bias = self.mlp(mlp_in)

        skip = self._residual_source(mlp_in, mlp_block_input)

        # Same forced-grad wrapping as the attention residual above.
        with torch.enable_grad():
            result = add_bias_dropout_residual(mlp_out, mlp_bias.expand_as(skip), skip, self.hidden_dropout)

        return result
|