mirror of https://github.com/hpcaitech/ColossalAI
import torch
import torch.nn as nn
import torch.nn.functional as F

from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl


class TransformerMLP(nn.Module):
    """MLP.

    Takes the input with h hidden dimensions, projects it to
    mlp_ratio * h hidden dimensions, performs a nonlinear transformation,
    and projects the state back to h hidden dimensions. The output bias
    is returned separately (skip_bias_add=True) so the caller can fuse
    it with the subsequent dropout and residual add.
    """

    def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
        super(TransformerMLP, self).__init__()

        # Project h -> mlp_ratio * h.
        self.dense_h_to_4h = Linear(
            hidden_size,
            int(hidden_size * mlp_ratio),
            skip_bias_add=True)

        self.bias_gelu_fusion = fuse_gelu
        self.activation_func = F.gelu

        # Project mlp_ratio * h -> h.
        self.dense_4h_to_h = Linear(
            int(hidden_size * mlp_ratio),
            hidden_size,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # hidden_states is expected in the shape [s, b, h];
        # it is projected to [s, b, mlp_ratio * h]
        # and then projected back to [s, b, h].
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            # Fused JIT kernel applies the bias add and GeLU in a single pass.
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
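

# --- Usage sketch (illustrative, not part of the mirrored file) ---
# Shows how the module is typically driven: a sequence-first input of shape
# [s, b, h] goes in, and the projection output and its bias come back
# separately because both Linear layers use skip_bias_add=True. The sizes
# below are arbitrary, and fuse_gelu is disabled so the plain F.gelu path
# runs without the JIT kernel; this assumes the package's Linear behaves
# like a standard linear layer and that the module is executed inside its
# package so the relative `.linear` import resolves.
if __name__ == "__main__":
    mlp = TransformerMLP(hidden_size=768, mlp_ratio=4, fuse_gelu=False)
    hidden_states = torch.randn(128, 2, 768)   # [s, b, h]
    output, output_bias = mlp(hidden_states)
    print(output.shape)        # expected: [128, 2, 768]
    print(output_bias.shape)   # bias of the final projection, added by the caller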