2022-11-02 07:12:08 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
def bias_sigmod_ele(y, bias, z):
    """Gate ``z`` elementwise by ``sigmoid(y + bias)``.

    Args:
        y: Pre-activation tensor for the gate.
        bias: Bias added to ``y`` before the sigmoid (broadcast per torch rules).
        z: Tensor multiplied elementwise by the gate.

    Returns:
        ``sigmoid(y + bias) * z``.
    """
    gate = torch.sigmoid(y + bias)
    return gate * z
|
|
|
|
|
|
|
|
|
|
|
|
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float,
                     training: bool = False) -> torch.Tensor:
    """Return ``residual + (x + bias) * F.dropout(dropmask)``.

    Args:
        x: Input tensor.
        bias: Bias added to ``x`` (broadcast per torch rules).
        dropmask: Tensor passed through ``F.dropout`` and multiplied into
            ``x + bias``.
        residual: Tensor added to the masked result.
        prob: Dropout probability forwarded to ``F.dropout``.
        training: Forwarded to ``F.dropout``. Defaults to ``False``, which
            preserves the previously hard-coded behavior. In that case
            ``F.dropout`` is the identity, so ``prob`` has no effect and
            ``dropmask`` is applied unscaled. Pass ``True`` to zero elements
            of ``dropmask`` with probability ``prob``; survivors are scaled
            by ``1 / (1 - prob)``.

    Returns:
        ``residual + (x + bias) * F.dropout(dropmask, p=prob, training=training)``.
    """
    # Previously ``training=False`` was hard-coded, which made ``prob`` dead.
    # Exposing it lets callers opt into real dropout without changing defaults.
    out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
    return residual + out
|
|
|
|
|
|
|
|
|
|
|
|
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
                              prob: float, training: bool = True) -> torch.Tensor:
    """Return ``Z_raw + F.dropout(dropout_mask) * (g * (ab + b))``.

    Args:
        ab: Input tensor to which ``b`` is added.
        b: Bias added to ``ab`` (broadcast per torch rules).
        g: Tensor multiplied elementwise into ``ab + b``.
        dropout_mask: Tensor passed through ``F.dropout`` before scaling the
            gated term.
        Z_raw: Residual tensor added to the result.
        prob: Dropout probability forwarded to ``F.dropout``.
        training: Forwarded to ``F.dropout``. Defaults to ``True``, matching
            the previously hard-coded behavior. Pass ``False`` to disable
            dropout, in which case ``dropout_mask`` is applied unchanged.

    Returns:
        ``Z_raw + F.dropout(dropout_mask, p=prob, training=training) * (g * (ab + b))``.
    """
    gated = g * (ab + b)
    return Z_raw + F.dropout(dropout_mask, p=prob, training=training) * gated
|