#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math

import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
from colossalai.registry import LAYERS
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.context import seed


@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
    """Parallel self-attention layer for ring sequence parallelism.

    The self-attention layer takes input of size [sub_seq_len, batch_size, hidden_size]
    and returns output of the same size.

    Args:
        hidden_size (int): hidden size.
        num_attention_heads (int): number of attention heads.
        attention_dropout (float): dropout probability for the attention layer.
        attention_mask_func (:class:`typing.Callable`): mask function to be applied to the attention scores.
        layer_number (int): 1-based index of this layer in the transformer stack.
        apply_query_key_layer_scaling (bool, optional): whether to scale the attention scores by the layer number.
        convert_fp16_to_fp32_in_softmax (bool, optional): whether to compute the softmax in fp32.
        attn_mask_type (:class:`AttnMaskType`, optional): type of the attention mask, padding by default.
        masked_softmax_fusion (bool, optional): whether to use the fused scale-mask-softmax kernel.
        fp16 (bool, optional): whether the model runs in fp16.
        bf16 (bool, optional): whether the model runs in bf16.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout,
                 attention_mask_func,
                 layer_number,
                 apply_query_key_layer_scaling: bool = False,
                 convert_fp16_to_fp32_in_softmax: bool = False,
                 attn_mask_type=AttnMaskType.padding,
                 masked_softmax_fusion=True,
                 fp16=False,
                 bf16=False):
        super().__init__()
        self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_mask_func = attention_mask_func
        self.layer_number = layer_number
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attn_mask_type = attn_mask_type
        assert self.layer_number > 0
        self.attention_dropout = attention_dropout

        if self.apply_query_key_layer_scaling:
            self.convert_fp16_to_fp32_in_softmax = True

        assert self.hidden_size % self.num_attention_heads == 0, \
            'hidden size is not divisible by the number of attention heads'

        self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads

        self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
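        # The full sequence is partitioned into world_size contiguous sub-sequences, one per
        # rank in the SEQUENCE parallel group; in forward() this layer only ever sees its
        # local [sub_seq_len, batch_size, hidden_size] chunk.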

        # Fused linear layer producing the query, key and value projections in one matmul.
        self.query_key_value = _Linear(
            hidden_size,
            3 * self.hidden_size,
        )

        self.coeff = None
        self.norm_factor = math.sqrt(self.hidden_size)

        if self.apply_query_key_layer_scaling:
            self.coeff = layer_number
            self.norm_factor *= self.coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            fp16, bf16,
            self.attn_mask_type,
            masked_softmax_fusion,
            self.attention_mask_func,
            self.convert_fp16_to_fp32_in_softmax,
            self.coeff)

        self.attention_dropout = nn.Dropout(attention_dropout)

        # Output projection.
        self.dense = _Linear(hidden_size,
                             hidden_size,
                             bias=True,
                             skip_bias_add=True)
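        # With skip_bias_add=True, the output projection returns its bias separately instead
        # of adding it, so the caller of forward() may fuse the bias with a later elementwise
        # op (e.g. a bias-dropout-add) if it chooses to.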

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [sub_seq_len, batch_size, hidden_size]
        # attention_mask: [batch_size, 1, sub_seq_len, seq_len]
        sub_seq_length, batch_size, hidden_size = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads shape change:
        # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)]
        mixed_x_layer = self.query_key_value(hidden_states)

        # [sub_seq_len, batch_size, 3 * hidden_size] --> [sub_seq_len, batch_size, num_heads, 3 * head_size]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
                                                        3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # split the last dimension into query, key and value
        last_dim = mixed_x_layer.dim() - 1
        last_dim_value = mixed_x_layer.size(-1)
        assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
            'cannot be divided into query, key and value'
        partition_size = last_dim_value // 3
        (query_layer, key_layer, value_layer) = torch.split(
            mixed_x_layer, partition_size, dim=last_dim)
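        # query_layer, key_layer, value_layer: [sub_seq_len, batch_size, num_heads, head_size]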

        # attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0) * self.world_size)

        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
        key_layer = key_layer.view(key_layer.size(0),
                                   output_size[0] * output_size[1], -1)

        # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
        attention_scores = RingQK.apply(
            query_layer.transpose(0, 1).contiguous(),  # [batch_size * num_heads, sub_seq_len, head_size]
            key_layer.transpose(0, 1).contiguous(),  # [batch_size * num_heads, sub_seq_len, head_size]
            batch_size,
            self.num_attention_heads,
            sub_seq_length
        )
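        # RingQK passes the key sub-blocks around the ring of SEQUENCE-parallel ranks so that
        # the local queries are scored against every key of the full sequence; as the output
        # shape indicates, the last dimension is seq_len = sub_seq_len * world_size.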

        attention_scores /= self.norm_factor

        # change view to [batch_size, num_heads, sub_seq_len, seq_len]
        attention_scores = attention_scores.view(*output_size)

        # scaled masked softmax; shape stays [batch_size, num_heads, sub_seq_len, seq_len]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)

        # context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view to [sub_seq_len, batch_size * num_heads, head_size]
        value_layer = value_layer.contiguous().view(value_layer.size(0),
                                                    output_size[0] * output_size[1], -1)

        # change view to [batch_size * num_heads, sub_seq_len, seq_len]
        attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
                                               attention_probs.size(2),
                                               attention_probs.size(3))

        # matmul: [batch_size * num_heads, sub_seq_len, head_size]
        context_layer = RingAV.apply(
            attention_probs,
            value_layer.transpose(0, 1).contiguous(),
            batch_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            sub_seq_length
        )
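        # RingAV is the ring-parallel counterpart of the attention-value matmul: the local
        # value sub-blocks are exchanged between ranks so that each rank ends up with the
        # full context for its own sub-sequence.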

        # change view to [batch_size, num_heads, sub_seq_len, head_size]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_attention_head * self.num_attention_heads,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output, bias = self.dense(context_layer)

        return output, bias

    def __repr__(self):
        return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \
            f'layer_number={self.layer_number}, hidden_size={self.hidden_size}, attention_dropout={self.attention_dropout}, ' \
            f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \
            f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \
            f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})'


class _Linear(nn.Module):
    """Linear layer defined as Y = XA + b. The full weight matrix A is kept on every rank;
    no tensor parallelism is applied here.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: if True, add a bias term; the bias is always initialized to zero.
        skip_bias_add: enables performance optimizations where the bias can be fused with
                       other elementwise operations; when True, forward() skips adding the
                       bias and returns it together with the output instead.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 bias=True,
                 skip_bias_add=False):
        super(_Linear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(torch.empty(self.output_size,
                                            self.input_size))
        nn.init.xavier_normal_(self.weight)

        if bias:
            self.bias = Parameter(torch.empty(self.output_size))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output = F.linear(input_, self.weight, bias)

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output

    def __repr__(self):
        return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
            f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
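
# A minimal usage sketch of `_Linear` (illustrative only, not part of the library API):
# with `skip_bias_add=True`, `forward` returns the bias separately so the caller can fuse
# it with a later elementwise operation. The sizes below are arbitrary example values.
if __name__ == '__main__':
    layer = _Linear(input_size=16, output_size=32, bias=True, skip_bias_add=True)
    x = torch.randn(4, 16)
    out, bias = layer(x)    # out: [4, 32], bias: [32]
    fused = out + bias      # equivalent to a plain linear forward
    ref = F.linear(x, layer.weight, layer.bias)
    assert torch.allclose(fused, ref)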