[hotfix] Jit type hint #2161 (#2164)

pull/2168/merge
アマデウス 2022-12-22 10:17:03 +08:00 committed by GitHub
parent 27327a4c90
commit 622f863291
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

4
colossalai/nn/layer/parallel_3d/_operation.py Normal file → Executable file
View File

@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
@torch.jit.script
def norm_forward(x, mean, sqr_mean, weight, bias, eps):
def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
mu = x - mean
var = sqr_mean - mean**2
sigma = torch.sqrt(var + eps)
@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
@torch.jit.script
def norm_backward(grad, mu, sigma, weight):
def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma
dz = grad * weight
dmu = dz / sigma