import torch
import torch.nn as nn

from .linear import Linear


class Pooler(nn.Module):
    """Pooler layer.

    Selects the hidden state of a single token (for example the first
    token of the sequence), passes it through a linear layer and then
    applies a tanh.

    Arguments:
        hidden_size: hidden size; used as both the input and the output
            size of the linear layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Square projection: hidden_size -> hidden_size. Weight/bias
        # initialization is whatever the project's Linear does by default.
        self.dense = Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, sequence_index=0):
        """Pool one token's hidden state and project it.

        Arguments:
            hidden_states: tensor laid out as [b, s, h]
                (presumably batch, sequence, hidden -- confirm with callers).
            sequence_index: index along dimension 1 of the token to pool.
                Defaults to 0.

        Returns:
            tanh of the linear projection of the selected token's states.
        """
        # Drop the sequence dimension by picking a single position.
        token_states = hidden_states[:, sequence_index, :]
        return torch.tanh(self.dense(token_states))