ColossalAI/colossalai/nn/layer/parallel_2d/_transformer.py

221 lines
8.0 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from .layers import Linear2D, LayerNorm2D
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2D(ParallelLayer):
"""
MLP will take the input with h hidden state, project it to mlp_ratio * h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
:param in_features: the size of input tensor
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: int, optional
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
:type skip_bias_add: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2D(
in_features,
int(mlp_ratio * in_features),
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
int(mlp_ratio * in_features),
in_features,
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention2D(ParallelLayer):
"""Self attention layer for 2D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.layernorm = LayerNorm2D(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer2D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for attention layer, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention2D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP2D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output