mirror of https://github.com/hpcaitech/ColossalAI
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import Parameter
|
|
import torch.nn.functional as F
|
|
import torch.nn.init as init
|
|
|
|
|
|
class Linear(nn.Module):
    """Plain (non-parallel) linear layer with an optional deferred bias add.

    Computes ``Y = X @ W^T + b`` via ``torch.nn.functional.linear``, where the
    weight ``W`` is stored with shape ``(output_size, input_size)``.

    Arguments:
        input_size: size of the last dimension of the input.
        output_size: size of the last dimension of the output.
        bias: If true, create a learnable bias. The bias is always
            initialized to zero. If false, ``self.bias`` is ``None``.
        skip_bias_add: This was added to enable performance optimizations
            where the bias can be fused with other elementwise operations.
            When true, ``forward`` does not add the bias but returns it
            alongside the output instead.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 bias=True,
                 skip_bias_add=False):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        # Weight uses the (out_features, in_features) layout expected by
        # F.linear; values are drawn from init.normal_ with its defaults.
        self.weight = Parameter(torch.empty(self.output_size, self.input_size))
        init.normal_(self.weight)

        if bias:
            # Always initialize bias to zero.
            self.bias = Parameter(torch.zeros(self.output_size))
        else:
            # Register the name so ``self.bias`` exists and evaluates to None.
            self.register_parameter('bias', None)

    def forward(self, input_):
        """Apply the linear transform to ``input_``.

        Returns:
            ``output`` when ``skip_bias_add`` is false (bias already added,
            if one exists); otherwise the tuple ``(output, self.bias)`` with
            the bias left un-added so the caller can fuse it. ``self.bias``
            is ``None`` when the layer was built with ``bias=False``.
        """
        if self.skip_bias_add:
            # Matrix multiply only; hand the bias back to the caller.
            return F.linear(input_, self.weight, None), self.bias

        # Matrix multiply with the bias (possibly None) applied in place.
        return F.linear(input_, self.weight, self.bias)

    def __repr__(self):
        return (f'Linear(in_features={self.input_size}, out_features={self.output_size}, '
                f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})')
|