# ColossalAI/examples/tutorial/sequence_parallel/model/layers/pooler.py (Python, 29 lines, 832 B)

import torch
import torch.nn as nn
from .linear import Linear
class Pooler(nn.Module):
    """Pooler layer.

    Pool the hidden state of a single token (for example the start of the
    sequence) and apply a linear transformation followed by a tanh.

    Arguments:
        hidden_size: size of the hidden dimension; the linear layer is
            constructed as ``Linear(hidden_size, hidden_size)``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Only the input/output sizes are passed, so weight/bias
        # initialization is whatever the project-local ``Linear`` does by
        # default; no init method is configurable here.
        self.dense = Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, sequence_index=0):
        """Pool one token position, project it, and squash with tanh.

        Arguments:
            hidden_states: 3-D tensor laid out as ``[b, s, h]``
                (batch, sequence, hidden).
            sequence_index: position along the sequence dimension of the
                token to pool. Defaults to 0.

        Returns:
            ``tanh(dense(hidden_states[:, sequence_index, :]))``; presumably
            of shape ``[b, hidden_size]`` (depends on ``Linear``'s output).
        """
        # Select one token per batch element: [b, s, h] -> [b, h].
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        return torch.tanh(pooled)