From 622f863291315d50b1afa54f2f37190455ce0db2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Thu, 22 Dec 2022 10:17:03 +0800 Subject: [PATCH] [hotfix] Jit type hint #2161 (#2164) --- colossalai/nn/layer/parallel_3d/_operation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) mode change 100644 => 100755 colossalai/nn/layer/parallel_3d/_operation.py diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py old mode 100644 new mode 100755 index 885d06e6d..07869e5ad --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d( @torch.jit.script -def norm_forward(x, mean, sqr_mean, weight, bias, eps): +def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float): mu = x - mean var = sqr_mean - mean**2 sigma = torch.sqrt(var + eps) @@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps): @torch.jit.script -def norm_backward(grad, mu, sigma, weight): +def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): # dbias, dweight = grad, grad * mu / sigma dz = grad * weight dmu = dz / sigma