import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init


class Linear(nn.Module):
    """Dense affine layer that can hand its bias back instead of adding it.

    The weight is stored with shape ``(output_size, input_size)`` and is
    applied through ``F.linear``. When ``skip_bias_add`` is set, the bias is
    not added to the product; it is returned next to the output so the
    caller can fuse the addition with other element-wise operations.

    Arguments:
        input_size: size of the last dimension of the input.
        output_size: size of the last dimension of the output.
        bias: if true, create a learnable bias vector (initialized to
            zero); otherwise ``self.bias`` is registered as ``None``.
        skip_bias_add: if true, ``forward`` leaves the bias out of the
            matrix multiply and returns ``(output, bias)`` instead of just
            ``output``.
    """

    def __init__(self, input_size, output_size, bias=True,
                 skip_bias_add=False):
        super().__init__()

        # Remember the construction arguments; __repr__ reports them.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        # Weight layout is (out, in), which is what F.linear expects.
        weight_storage = torch.empty(output_size, input_size)
        self.weight = Parameter(weight_storage)
        init.normal_(self.weight)

        if not bias:
            # Register an explicit None so that `self.bias` always resolves.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(output_size))
            # The bias always starts out at zero.
            with torch.no_grad():
                self.bias.zero_()

    def forward(self, input_):
        """Apply the layer to ``input_``.

        Returns the output tensor, or the pair ``(output, self.bias)`` with
        the bias left un-added when ``skip_bias_add`` is true (the returned
        bias is ``None`` if the layer was built without one).
        """
        if self.skip_bias_add:
            # Leave the bias out of the product and give it to the caller.
            product = F.linear(input_, self.weight, None)
            return product, self.bias
        return F.linear(input_, self.weight, self.bias)

    def __repr__(self):
        has_bias = self.bias is not None
        return (
            f'Linear(in_features={self.input_size}, '
            f'out_features={self.output_size}, '
            f'bias={has_bias}, skip_bias_add={self.skip_bias_add})'
        )