[TP] allow layernorm without bias (#750)

pull/756/head
アマデウス 2022-04-14 11:43:56 +08:00 committed by GitHub
parent 3d7dc46d33
commit b8899e0905
8 changed files with 134 additions and 53 deletions


@@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
 from ..parallel_2p5d import LayerNorm2p5D
 from ..parallel_3d import LayerNorm3D
 from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
 from ._utils import ColossalaiModule
 
-_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {
+    None: VanillaLayerNorm,
+    "1d": LayerNorm1D,
+    "2d": LayerNorm2D,
+    "2.5d": LayerNorm2p5D,
+    "3d": LayerNorm3D,
+}
 
 
 class LayerNorm(ColossalaiModule):
@@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
 
     Args:
         normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
-        eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
             norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
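
With the `None` entry added, `_parallel_layernorm` covers the non-parallel case through the same table as the parallel variants. The snippet below is a minimal sketch of that lookup pattern with plain `nn.LayerNorm` stand-ins; it is illustrative only, not the library's actual wiring.

```python
from typing import Dict, Optional, Type

import torch.nn as nn

# Placeholder registry mirroring the shape of _parallel_layernorm; the values here are
# ordinary nn.LayerNorm stand-ins rather than the ColossalAI parallel classes.
_layernorm_registry: Dict[Optional[str], Type[nn.Module]] = {
    None: nn.LayerNorm,   # no tensor parallelism -> vanilla implementation
    "1d": nn.LayerNorm,
    "2d": nn.LayerNorm,
    "2.5d": nn.LayerNorm,
    "3d": nn.LayerNorm,
}


def build_layernorm(mode: Optional[str], normalized_shape: int, eps: float = 1e-5) -> nn.Module:
    # One table lookup handles both the parallel and the non-parallel case.
    return _layernorm_registry[mode](normalized_shape, eps=eps)


print(type(build_layernorm(None, 768)).__name__)  # LayerNorm
```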


@@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
-from ..vanilla import VanillaPatchEmbedding
+from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
@@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
     r"""
     Layer Normalization for colossalai
 
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-        \times \ldots \times \text{normalized_shape}[-1]]`
-        If a single integer is used, it is treated as a singleton list, and this module will
-        normalize over the last dimension which is expected to be of that specific size.
-    :type normalized_shape: int
-    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
-    :type eps: float, optional
-    :param dtype: The dtype of parameters, defaults to None
-    :type dtype: torch.dtype, optional
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
-        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
         super().__init__(norm)
 
     def _load_from_state_dict(self, state_dict, prefix, *args):
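
`LayerNorm1D` keeps its thin-wrapper shape; the only change is that the inner module is now `VanillaLayerNorm`, which understands the `bias` flag. The sketch below imitates that delegation with plain PyTorch; `WrapperModule` and `LayerNorm1DSketch` are hypothetical names, and dropping the bias parameter on `nn.LayerNorm` merely emulates what `VanillaLayerNorm(bias=False)` does.

```python
import torch
import torch.nn as nn


class WrapperModule(nn.Module):
    """Minimal stand-in for ColossalaiModule: hold an inner module and forward to it."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


class LayerNorm1DSketch(WrapperModule):
    """Hypothetical sketch of the new constructor: the bias flag is forwarded to the inner norm."""

    def __init__(self, normalized_shape: int, eps: float = 1e-5, bias: bool = True):
        inner = nn.LayerNorm(normalized_shape, eps=eps)
        if not bias:
            inner.bias = None   # emulate bias=False; F.layer_norm accepts a None bias
        super().__init__(inner)


x = torch.randn(2, 16)
print(LayerNorm1DSketch(16, bias=False)(x).shape)  # torch.Size([2, 16])
```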


@@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config
@@ -239,13 +240,17 @@
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -294,7 +299,9 @@
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
@@ -345,13 +352,17 @@
         output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                               ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
-                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
-                           self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
         scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
+                               ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
+                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                               self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
 
         return output
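
The only numerical change in the 2D forward pass is the final combine: with a bias the layer still computes `bias + scale * normalized` via `torch.addcmul`, and without one it falls back to a plain elementwise product. A small single-device sketch of the two paths, with random tensors standing in for the broadcast weight and bias partitions:

```python
import torch

normalized = torch.randn(4, 8)   # stand-in for the layernorm_2d output
scale = torch.rand(8) + 0.5      # stand-in for the broadcast weight partition
bias = torch.randn(8)            # stand-in for the broadcast bias partition

with_bias = torch.addcmul(bias, scale, normalized)   # bias + scale * normalized
without_bias = torch.mul(scale, normalized)          # scale * normalized

# With a zero bias the addcmul path reduces exactly to the bias-free path.
assert torch.allclose(torch.addcmul(torch.zeros(8), scale, normalized), without_bias)
```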


@@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config
@@ -259,13 +260,17 @@
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attribute()
 
     def _set_tensor_parallel_attribute(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -314,7 +319,9 @@
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
@@ -364,15 +371,18 @@
         Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
 
         output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
-        bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
-                             self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
-                             self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                             self.tensor_parallel_size)
         scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
                               self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                               self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
+                                 self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
+                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                                 self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
 
         return output
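
The checkpoint handling in the 2.5D layer relies on the same convention as the 2D one: a parameter that was never created simply does not appear in the state dict, so there is nothing to gather, broadcast, or load for it. The stand-alone sketch below shows the convention with a hypothetical `OptionalBiasNorm`; it is not the 2.5D layer itself.

```python
import torch
import torch.nn as nn


class OptionalBiasNorm(nn.Module):
    # Hypothetical module mirroring the convention: register the bias only when requested.
    def __init__(self, dim: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None


print(sorted(OptionalBiasNorm(8, bias=True).state_dict()))   # ['bias', 'weight']
print(sorted(OptionalBiasNorm(8, bias=False).state_dict()))  # ['weight'] -- no bias key to sync
```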


@@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
         mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@@ -201,8 +201,11 @@
         ctx.save_for_backward(mu, sigma, weight)
 
         z = mu / sigma
-        output = weight * z + bias
+        output = weight * z
+        if bias is not None:
+            output = output + bias
 
+        ctx.use_bias = bias is not None
         ctx.normalized_shape = normalized_shape
         ctx.input_parallel_mode = input_parallel_mode
         ctx.weight_parallel_mode = weight_parallel_mode
@@ -215,12 +218,17 @@
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
         with torch.no_grad():
-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            grads = torch.stack([bias_grad, weight_grad]).contiguous()
-            grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
-            grads = all_reduce(grads, ctx.weight_parallel_mode)
-            grads = all_reduce(grads, ctx.input_parallel_mode)
-            bias_grad, weight_grad = grads[0], grads[1]
+            weight_grad = output_grad * mu / sigma
+            if ctx.use_bias:
+                bias_grad = output_grad
+                weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
+            else:
+                bias_grad = None
+            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
+            weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
+            weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
+            if ctx.use_bias:
+                bias_grad, weight_grad = weight_grad[0], weight_grad[1]
 
         dz = output_grad * weight
         dvar = dz * mu * (-0.5) * sigma**(-3)
@@ -234,7 +242,7 @@
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
 
 
-def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
     r"""3D parallel Layernorm.


@@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
         super().__init__()
 
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -51,18 +52,23 @@
 
         self.weight = Parameter(
             torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
-        self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
-                                          dtype=dtype))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
+                                              device=get_current_device(), dtype=dtype))
+        else:
+            self.bias = None
         self.variance_epsilon = eps
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self) -> None:
         set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
 
     def reset_parameters(self) -> None:
-        init.zeros_()(self.bias)
         init.ones_()(self.weight)
+        if self.bias is not None:
+            init.zeros_()(self.bias)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -104,7 +110,9 @@
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in output groups
         if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
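
`LayerNorm3D` now creates its bias partition only on demand and guards both the tensor-parallel attribute tagging and `reset_parameters` on its presence. Below is a plain-PyTorch sketch of that parameter-handling pattern; `NormParamsSketch` is a hypothetical stand-in and `torch.nn.init` replaces colossalai's `init` wrappers.

```python
import torch
import torch.nn as nn


class NormParamsSketch(nn.Module):
    # Hypothetical stand-in for LayerNorm3D's parameter handling: the weight always
    # exists, the bias only when requested, and reset_parameters guards the bias init.
    def __init__(self, partition_size: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(partition_size))
        self.bias = nn.Parameter(torch.empty(partition_size)) if bias else None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


print([name for name, _ in NormParamsSketch(128, bias=False).named_parameters()])  # ['weight']
```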


@@ -1,5 +1,7 @@
-from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
-    WrappedDropout, WrappedDropPath
+from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm,
+                     VanillaPatchEmbedding, WrappedDropout, WrappedDropPath)
 
-__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
-           'WrappedDropout', 'WrappedDropPath']
+__all__ = [
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier",
+    "DropPath", "WrappedDropout", "WrappedDropPath"
+]


@@ -254,3 +254,37 @@
 
     def forward(self, input_: Tensor) -> Tensor:
         return F.linear(input_, self.weight, self.bias)
+
+
+@LAYERS.register_module
+class VanillaLayerNorm(nn.Module):
+    r"""
+    Layer Normalization for colossalai
+
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        super().__init__()
+
+        self.normalized_shape = (normalized_shape,)
+        self.variance_epsilon = eps
+
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
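
`VanillaLayerNorm.forward` is a direct call to `F.layer_norm`, which accepts a `None` bias; in that case the output is just the normalized input scaled by `weight`, with no shift. A quick equivalence check against the manual formula (float32, default tolerances assumed):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)
weight = torch.randn(16)
eps = 1e-5

# What a bias-free vanilla layer norm computes.
out = F.layer_norm(x, (16,), weight, None, eps)

# Manual reference: normalize over the last dim (biased variance), then scale only.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
ref = (x - mean) / torch.sqrt(var + eps) * weight

assert torch.allclose(out, ref, atol=1e-5)
```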