# ColossalAI/examples/tutorial/sequence_parallel/model/layers/bert_layer.py
#
# File listing metadata: 113 lines, 4.2 KiB, Python.

import torch
import torch.nn as nn
from colossalai.kernel.cuda_native import LayerNorm
from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
from colossalai.legacy.nn.layer.parallel_sequence import TransformerSelfAttentionRing
from .dropout import get_bias_dropout_add
from .mlp import TransformerMLP
def attention_mask_func(attention_scores, attention_mask):
    """Mask attention scores in place and return them.

    Every position of ``attention_scores`` where ``attention_mask`` is True
    is overwritten with ``-10000.0``. The fill is done in place
    (``masked_fill_``), so the caller's tensor is modified and the very same
    tensor object is returned.

    Args:
        attention_scores: tensor of attention scores; modified in place.
        attention_mask: boolean mask passed to ``masked_fill_``.

    Returns:
        ``attention_scores`` itself, after masking.
    """
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores
class BertLayer(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.

    The forward pass applies, in order: input layer norm, ring
    self-attention, bias-dropout-add against the first residual,
    post-attention layer norm, MLP, and bias-dropout-add against the
    second residual.

    Args:
        layer_number: stored on the instance and forwarded to the
            self-attention module.
        hidden_size: size given to both layer norms, the self-attention
            module and the MLP.
        num_attention_heads: forwarded to the self-attention module.
        attention_dropout: forwarded to the self-attention module.
        mlp_ratio: forwarded to the MLP.
        hidden_dropout: dropout value handed to the bias-dropout-add
            routine for both residual connections.
        is_naive_fp16: forwarded to the self-attention module as ``fp16``.
        apply_residual_connection_post_layernorm: when True, each residual
            is taken from the layer-norm output instead of the layer-norm
            input.
        fp32_residual_connection: stored on the instance; it is not read
            anywhere else in this class.
        bias_dropout_fusion: when True, the fused train/inference
            bias-dropout-add kernels are used; otherwise the routine
            returned by ``get_bias_dropout_add`` is used.
        convert_fp16_to_fp32_in_softmax: forwarded to the self-attention
            module.
    """

    def __init__(self,
                 layer_number,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout,
                 mlp_ratio,
                 hidden_dropout,
                 is_naive_fp16,
                 apply_residual_connection_post_layernorm=False,
                 fp32_residual_connection=False,
                 bias_dropout_fusion: bool = True,
                 convert_fp16_to_fp32_in_softmax: bool = False):
        super().__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.fp32_residual_connection = fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size)

        # Self attention.
        self.self_attention = TransformerSelfAttentionRing(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout=attention_dropout,
            attention_mask_func=attention_mask_func,
            layer_number=layer_number,
            apply_query_key_layer_scaling=True,
            convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
            fp16=is_naive_fp16)

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(hidden_size)

        self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def forward(self, hidden_states, attention_mask):
        """Run the layer on ``hidden_states`` and return the layer output.

        Args:
            hidden_states: layer input; passed through the input layer norm
                and (unless residuals are taken post-layernorm) used as the
                first residual.
            attention_mask: mask forwarded unchanged to the self-attention
                module.

        Returns:
            The result of the second bias-dropout-add, i.e. the MLP output
            plus its bias (with dropout) added to the second residual.
        """
        # hidden_states: [batch_size, sub_seq_len, hidden_size]
        # attention_mask: [batch_size, 1, sub_seq_len, seq_len]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention. The bias is returned separately so it can be
        # applied inside the bias-dropout-add routine below.
        attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual,
                                                    self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP. As with attention, the bias comes back separately.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)

        return output