mirror of https://github.com/hpcaitech/ColossalAI
29 lines
832 B
Python
29 lines
832 B
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from .linear import Linear
|
||
|
|
||
|
|
||
|
class Pooler(nn.Module):
    """Pooler layer.

    Select the hidden state of a single token (for example, the first token
    of the sequence), apply a linear transformation to it, then apply tanh.

    Args:
        hidden_size: size of the last (hidden) dimension of the input. The
            projection is built as ``Linear(hidden_size, hidden_size)``.
            Weight and bias initialization are whatever the project's
            ``Linear`` layer does by default; this class does not configure them.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Square projection applied to the pooled token.
        self.dense = Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, sequence_index=0):
        """Pool one token per sequence and project it.

        Args:
            hidden_states: tensor laid out as ``[b, s, h]``
                (batch, sequence, hidden).
            sequence_index: position along the sequence dimension of the
                token to pool. Defaults to 0 (the first token).

        Returns:
            ``tanh(dense(hidden_states[:, sequence_index, :]))``. The index
            removes the sequence dimension, so the result is presumably
            ``[b, h]``, assuming ``Linear`` maps the last dimension
            ``h -> h``.
        """
        # [b, s, h] -> [b, h]: keep only the chosen token of every sequence.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        return torch.tanh(pooled)
|