mirror of https://github.com/hpcaitech/ColossalAI
added skip_bias_add for non-tp linear
parent
e5b1a0c9be
commit
653b0a620e
|
@ -1,10 +1,11 @@
|
||||||
import math
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
from torch import dtype, nn
|
from torch import dtype, nn
|
||||||
|
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
from ... import init as init
|
from ... import init as init
|
||||||
from ..parallel_1d import *
|
from ..parallel_1d import *
|
||||||
from ..parallel_2d import *
|
from ..parallel_2d import *
|
||||||
|
@ -14,7 +15,7 @@ from ..utils import get_tensor_parallel_mode
|
||||||
from ..vanilla import *
|
from ..vanilla import *
|
||||||
from ._utils import ColossalaiModule
|
from ._utils import ColossalaiModule
|
||||||
|
|
||||||
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
||||||
|
|
||||||
_parallel_classifier = {
|
_parallel_classifier = {
|
||||||
None: VanillaClassifier,
|
None: VanillaClassifier,
|
||||||
|
@ -73,16 +74,9 @@ class Linear(ColossalaiModule):
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
tensor_parallel = get_tensor_parallel_mode()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
if tensor_parallel is None:
|
|
||||||
layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
|
|
||||||
weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
|
|
||||||
if layer.bias is not None:
|
|
||||||
bias_initializer(layer.bias, fan_in=in_features)
|
|
||||||
else:
|
|
||||||
linear_cls = _parallel_linear[tensor_parallel]
|
linear_cls = _parallel_linear[tensor_parallel]
|
||||||
gather_output = kwargs.pop('gather_output', None)
|
gather_output = kwargs.pop('gather_output', None)
|
||||||
if 'gather_output' in inspect.signature(
|
if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
|
||||||
linear_cls.__init__).parameters.keys(): # gather_out arg is available
|
|
||||||
kwargs['gather_output'] = gather_output
|
kwargs['gather_output'] = gather_output
|
||||||
layer = linear_cls(
|
layer = linear_cls(
|
||||||
in_features,
|
in_features,
|
||||||
|
|
|
@ -1,6 +1,14 @@
|
||||||
from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
|
from .layers import (
|
||||||
WrappedDropPath)
|
DropPath,
|
||||||
|
VanillaClassifier,
|
||||||
|
VanillaLayerNorm,
|
||||||
|
VanillaLinear,
|
||||||
|
VanillaPatchEmbedding,
|
||||||
|
WrappedDropout,
|
||||||
|
WrappedDropPath,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"
|
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
|
||||||
|
"VanillaLinear"
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,12 +3,14 @@ from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from colossalai.context import seed
|
from colossalai.context import seed
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.registry import LAYERS
|
from colossalai.registry import LAYERS
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from torch import Tensor
|
|
||||||
from torch import nn as nn
|
|
||||||
|
|
||||||
from ..utils import to_2tuple
|
from ..utils import to_2tuple
|
||||||
|
|
||||||
|
@ -288,3 +290,52 @@ class VanillaLayerNorm(nn.Module):
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
|
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
|
||||||
|
|
||||||
|
|
||||||
|
@LAYERS.register_module
|
||||||
|
class VanillaLinear(nn.Module):
|
||||||
|
"""Linear layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_features (int): size of each input sample.
|
||||||
|
out_features (int): size of each output sample.
|
||||||
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
|
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||||
|
skip_bias_add: bool (optional, default to be false).
|
||||||
|
weight_initializer (:class:`typing.Callable`, optional):
|
||||||
|
The initializer of weight, defaults to kaiming uniform initializer.
|
||||||
|
bias_initializer (:class:`typing.Callable`, optional):
|
||||||
|
The initializer of bias, defaults to xavier uniform initializer.
|
||||||
|
|
||||||
|
More details about ``initializer`` please refer to
|
||||||
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
bias: bool = True,
|
||||||
|
dtype: torch.dtype = None,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
|
**kwargs) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||||
|
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||||
|
if bias:
|
||||||
|
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
|
||||||
|
if self.bias is not None:
|
||||||
|
bias_initializer(self.bias, fan_in=in_features)
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
if not self.skip_bias_add:
|
||||||
|
return F.linear(input, self.weight, self.bias)
|
||||||
|
else:
|
||||||
|
return F.linear(input, self.weight), self.bias
|
||||||
|
|
Loading…
Reference in New Issue