added skip_bias_add for non-tp linear

pull/1865/head
zbian 2022-11-09 13:20:02 +08:00 committed by アマデウス
parent e5b1a0c9be
commit 653b0a620e
3 changed files with 493 additions and 440 deletions

View File

@@ -1,10 +1,11 @@
-import math
 import inspect
+import math
 from typing import Callable

-from colossalai.utils import get_current_device
 from torch import dtype, nn

+from colossalai.utils import get_current_device
+
 from ... import init as init
 from ..parallel_1d import *
 from ..parallel_2d import *
@@ -14,7 +15,7 @@ from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
 from ._utils import ColossalaiModule

-_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
+_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}

 _parallel_classifier = {
     None: VanillaClassifier,
@@ -73,26 +74,19 @@ class Linear(ColossalaiModule):
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  **kwargs) -> None:
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel is None:
-            layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
-            weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
-            if layer.bias is not None:
-                bias_initializer(layer.bias, fan_in=in_features)
-        else:
-            linear_cls = _parallel_linear[tensor_parallel]
-            gather_output = kwargs.pop('gather_output', None)
-            if 'gather_output' in inspect.signature(
-                    linear_cls.__init__).parameters.keys():    # gather_out arg is available
-                kwargs['gather_output'] = gather_output
-            layer = linear_cls(
-                in_features,
-                out_features,
-                bias=bias,
-                dtype=dtype,
-                weight_initializer=weight_initializer,
-                bias_initializer=bias_initializer,
-                **kwargs,
-            )
+        linear_cls = _parallel_linear[tensor_parallel]
+        gather_output = kwargs.pop('gather_output', None)
+        if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():    # gather_out arg is available
+            kwargs['gather_output'] = gather_output
+        layer = linear_cls(
+            in_features,
+            out_features,
+            bias=bias,
+            dtype=dtype,
+            weight_initializer=weight_initializer,
+            bias_initializer=bias_initializer,
+            **kwargs,
+        )
         super().__init__(layer)
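
The net effect of this hunk: the `tensor_parallel is None` special case is gone, because `VanillaLinear` is now registered under the `None` key of `_parallel_linear`, so every mode (including no tensor parallelism) goes through the same lookup, and extra kwargs such as `skip_bias_add` reach the vanilla layer too. Below is a minimal, self-contained sketch of this registry-plus-`inspect.signature` dispatch pattern; the names `PlainLinear`, `ShardedLinear`, and `build_linear` are illustrative stand-ins, not ColossalAI APIs.

import inspect

from torch import nn


class PlainLinear(nn.Linear):
    """Stand-in for a non-parallel linear; its constructor has no `gather_output`."""


class ShardedLinear(nn.Linear):
    """Stand-in for a tensor-parallel linear whose constructor accepts `gather_output`."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True, gather_output: bool = False):
        super().__init__(in_features, out_features, bias=bias)
        self.gather_output = gather_output


# One registry keyed by tensor-parallel mode; `None` means "no tensor parallelism".
_LINEAR_REGISTRY = {None: PlainLinear, '1d': ShardedLinear}


def build_linear(mode, in_features, out_features, **kwargs):
    linear_cls = _LINEAR_REGISTRY[mode]
    # Forward `gather_output` only to classes that actually accept it,
    # mirroring the `inspect.signature(...)` check in the hunk above.
    gather_output = kwargs.pop('gather_output', None)
    if 'gather_output' in inspect.signature(linear_cls.__init__).parameters:
        kwargs['gather_output'] = gather_output
    return linear_cls(in_features, out_features, **kwargs)


vanilla = build_linear(None, 16, 32, gather_output=True)    # flag silently dropped
sharded = build_linear('1d', 16, 32, gather_output=True)    # flag forwarded

The `inspect.signature` check is what lets a single wrapper serve layer classes with different constructor signatures, without try/except or per-mode branching.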

View File

@@ -1,6 +1,14 @@
-from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
-                     WrappedDropPath)
+from .layers import (
+    DropPath,
+    VanillaClassifier,
+    VanillaLayerNorm,
+    VanillaLinear,
+    VanillaPatchEmbedding,
+    WrappedDropout,
+    WrappedDropPath,
+)

 __all__ = [
-    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
+    "VanillaLinear"
 ]

View File

@@ -3,12 +3,14 @@ from typing import Callable
 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch import nn as nn
+from torch.nn.parameter import Parameter
+
 from colossalai.context import seed
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch import nn as nn

 from ..utils import to_2tuple
@@ -288,3 +290,52 @@ class VanillaLayerNorm(nn.Module):

     def forward(self, x: Tensor) -> Tensor:
         return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
+
+
+@LAYERS.register_module
+class VanillaLinear(nn.Module):
+    """Linear layer.
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        skip_bias_add (bool, optional): If set to ``True``, the bias is not added in ``forward`` but returned
+            together with the output so it can be fused into a later kernel, defaults to ``False``.
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    For more details about ``initializer``, please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 **kwargs) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.skip_bias_add = skip_bias_add
+
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+        else:
+            self.bias = None
+
+        weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=in_features)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not self.skip_bias_add:
+            return F.linear(input, self.weight, self.bias)
+        else:
+            return F.linear(input, self.weight), self.bias
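
What `skip_bias_add` buys you: with the flag on, `forward` returns the un-biased matmul result together with the bias tensor, so the caller can add the bias later, typically fused with whatever comes next (dropout, a residual add, an activation). A small self-contained sketch of that calling contract in plain PyTorch follows; `SkipBiasLinear` is an illustrative stand-in, not the ColossalAI class, but its `forward` mirrors the one above.

import torch
import torch.nn.functional as F
from torch import nn


class SkipBiasLinear(nn.Linear):
    """Plain linear layer that can defer its bias add, mirroring VanillaLinear.forward above."""

    def __init__(self, in_features: int, out_features: int, skip_bias_add: bool = False):
        super().__init__(in_features, out_features, bias=True)
        self.skip_bias_add = skip_bias_add

    def forward(self, x: torch.Tensor):
        if not self.skip_bias_add:
            return F.linear(x, self.weight, self.bias)
        # Return the raw projection and the bias separately; the caller adds it later.
        return F.linear(x, self.weight), self.bias


layer = SkipBiasLinear(8, 4, skip_bias_add=True)
x = torch.randn(2, 8)
out, bias = layer(x)
# The bias add can now be fused with the following op instead of happening inside the layer.
result = F.dropout(out + bias, p=0.1, training=False)
assert torch.allclose(result, F.linear(x, layer.weight, layer.bias))

Presumably the flag exists here to give the non-tensor-parallel path the same interface as the parallel linear layers, so callers can rely on the (output, bias) contract regardless of the tensor-parallel mode.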