mirror of https://github.com/hpcaitech/ColossalAI
parent
27327a4c90
commit
622f863291
|
@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
|
|||
|
||||
|
||||
@torch.jit.script
|
||||
def norm_forward(x, mean, sqr_mean, weight, bias, eps):
|
||||
def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
|
||||
mu = x - mean
|
||||
var = sqr_mean - mean**2
|
||||
sigma = torch.sqrt(var + eps)
|
||||
|
@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
|
|||
|
||||
|
||||
@torch.jit.script
|
||||
def norm_backward(grad, mu, sigma, weight):
|
||||
def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
|
||||
# dbias, dweight = grad, grad * mu / sigma
|
||||
dz = grad * weight
|
||||
dmu = dz / sigma
|
||||
|
|
Loading…
Reference in New Issue