diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp index 8571f5f71..4690277e6 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -2,11 +2,13 @@ * https://github.com/NVIDIA/apex * with minor changes. */ -#include "compat.h" -#include #include + +#include #include +#include "compat.h" + namespace { void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, @@ -65,7 +67,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape, check_args(input, normalized_shape, n1, n2); check_args(normalized_shape, gamma, beta); } -} // namespace +} // namespace void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, at::Tensor *input, int n1, int n2, @@ -73,17 +75,16 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, at::Tensor *beta, double epsilon); #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ +#define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) std::vector layer_norm_affine(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, double epsilon) { - CHECK_INPUT(input); CHECK_INPUT(gamma); CHECK_INPUT(beta); @@ -109,11 +110,10 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, double epsilon, at::Tensor *grad_input, at::Tensor *grad_gamma, at::Tensor *grad_beta); -std::vector -layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar, - at::Tensor input, at::IntArrayRef normalized_shape, - at::Tensor gamma, at::Tensor beta, double epsilon) { - +std::vector layer_norm_gradient_affine( + at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, + at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, + double epsilon) { CHECK_INPUT(dout); CHECK_INPUT(mean); CHECK_INPUT(invvar);