ColossalAI/colossalai/nn/layer/parallel_1d/layers.py

167 lines
5.7 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device

from .._common_utils import divide
from .._parallel_utilities import (reduce_grad, reduce_input,
                                   gather_forward_split_backward,
                                   split_forward_gather_backward)
from ..base_layer import ParallelLayer
class Linear1D_Col(ParallelLayer):
    """Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`, so each rank of the
    1D parallel group stores ``output_size / world_size`` output features.

    :param in_features: first dimension of matrix A.
    :type in_features: int
    :param output_size: second dimension of matrix A.
    :type output_size: int
    :param bias: If true, add bias, defaults to True. If false, no bias is
        learned and :meth:`forward` returns the tuple ``(output, None)``
        instead of a single tensor.
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param gather_output: If true, call all-gather on output and make Y available
        to all GPUs, otherwise, every GPU will have its output
        which is :math:`Y_i = XA_i`, defaults to False
    :type gather_output: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 output_size: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 gather_output: bool = False):
        super().__init__()

        # Keep input parameters
        self.input_size = in_features
        self.output_size = output_size
        self.gather_output = gather_output
        # In this layer "skip bias add" simply means "no bias". It also makes
        # forward() return an (output, bias) tuple.
        self.skip_bias_add = not bias

        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.output_size_per_partition = divide(output_size, world_size)

        # Parameters.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        # Local shard of A, laid out as (out_features_per_partition, in_features)
        # because F.linear expects the weight in (out, in) order.
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            **factory_kwargs))

        if bias:
            # The bias is partitioned together with the output features.
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition,
                **factory_kwargs))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        # torch.empty() leaves the weight holding arbitrary memory (possibly
        # NaN/inf), so it must be initialized explicitly.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the local weight shard with Xavier-normal values.

        This matches ``Linear1D_Row.reset_parameters``. The fans are computed
        from the local shard's shape, not from the full, unpartitioned matrix.
        The bias, if any, is left unchanged.
        """
        init.xavier_normal_(self.weight)

    def forward(self, input_: Tensor
                ) -> Union[Tensor, Tuple[Tensor, Optional[Tensor]]]:
        """Apply the column-parallel linear transform.

        :param input_: input whose last dimension is ``in_features``
            (presumably replicated on every rank of the 1D group).
        :type input_: Tensor
        :return: ``output`` when the layer has a bias, otherwise the tuple
            ``(output, None)``. The last dimension of ``output`` is gathered
            across the group when ``gather_output`` is true; otherwise it is
            this rank's ``output_size / world_size`` slice.
        """
        # Set up backprop all-reduce of the input gradient over the 1D group.
        input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)

        # Matrix multiply on the local shard.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)

        if self.gather_output:
            # All-gather the partial outputs along the feature dimension.
            output = gather_forward_split_backward(
                output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
        else:
            output = output_parallel

        if self.skip_bias_add:
            # Kept for backward compatibility: bias-less layers return a tuple.
            return output, self.bias
        return output
@LAYERS.register_module
class Linear1D_Row(ParallelLayer):
    """Linear layer with row parallelism.

    The weight is partitioned along its input dimension. Each rank of the 1D
    parallel group stores ``in_features / world_size`` input columns and
    computes a partial product. The partial products are then combined across
    the group with ``reduce_input``.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param parallel_input: If set to ``True``, the input is assumed to be already
        split along its last dimension across the 1D group. Otherwise the full
        input is split inside :meth:`forward`. Defaults to False
    :type parallel_input: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 parallel_input: bool = False
                 ):
        super().__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = not bias

        # Divide the weight matrix along the last dimension.
        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.input_size_per_partition = divide(in_features, world_size)

        # Parameters.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        # Local shard laid out as (out_features, in_features_per_partition)
        # for F.linear.
        self.weight = Parameter(torch.empty(
            self.out_features,
            self.input_size_per_partition,
            **factory_kwargs))

        if bias:
            # The bias is not partitioned; every rank holds the full vector.
            self.bias = Parameter(torch.empty(
                self.out_features,
                **factory_kwargs
            ))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        # torch.empty() leaves the weight holding arbitrary memory (possibly
        # NaN/inf). reset_parameters() existed but was never called, so the
        # weight stayed uninitialized.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the local weight shard with Xavier-normal values.

        The fans are computed from the local shard's shape. The bias, if any,
        is left unchanged.
        """
        init.xavier_normal_(self.weight)

    def forward(self, input_: Tensor) -> Tensor:
        """Apply the row-parallel linear transform.

        :param input_: input whose last dimension is ``in_features``, or this
            rank's ``in_features / world_size`` slice when ``parallel_input``
            is true.
        :type input_: Tensor
        :return: output whose last dimension is ``out_features``, with the bias
            added when the layer has one.
        """
        if not self.parallel_input:
            # Keep only this rank's slice of the input features. The gradient
            # is gathered back during backward.
            input_ = split_forward_gather_backward(
                input_, ParallelMode.PARALLEL_1D, dim=-1)

        output_parallel = F.linear(input_, self.weight)
        # Combine the partial products from all ranks of the 1D group.
        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
        # Add the bias after the reduction so it is applied once rather than
        # folded into each rank's partial sum (assuming reduce_input sums).
        if not self.skip_bias_add:
            output = output + self.bias
        return output