From 0f1da44e5e88f87e2938fdb5be8ce91b3f123bf4 Mon Sep 17 00:00:00 2001
From: Jie Zhu
Date: Thu, 31 Mar 2022 15:01:51 +0800
Subject: [PATCH] [format]colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp (#566)

---
 .../cuda_native/csrc/layer_norm_cuda.cpp | 222 +++++++-----------
 1 file changed, 89 insertions(+), 133 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
index c42d91d36..8571f5f71 100644
--- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
+++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
@@ -2,158 +2,117 @@
  * https://github.com/NVIDIA/apex
  * with minor changes. */
+#include "compat.h"
+#include <cassert>
 #include <torch/extension.h>
 #include <vector>
-#include <cassert>
-#include "compat.h"
 
 namespace {
-void compute_n1_n2(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    int& n1,
-    int& n2) {
-  int idiff = input.ndimension() - normalized_shape.size();
-  n2 = 1;
-  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
-    assert( input.sizes()[i+idiff] == normalized_shape[i] );
-    n2 *= normalized_shape[i];
-  }
-  n1 = 1;
-  for (int i = 0; i < idiff; ++i) {
-    n1 *= input.sizes()[i];
+void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
+                   int &n2) {
+  int idiff = input.ndimension() - normalized_shape.size();
+  n2 = 1;
+  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
+    assert(input.sizes()[i + idiff] == normalized_shape[i]);
+    n2 *= normalized_shape[i];
+  }
+  n1 = 1;
+  for (int i = 0; i < idiff; ++i) {
+    n1 *= input.sizes()[i];
+  }
+}
+
+void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
+                at::Tensor beta) {
+  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
+  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
+}
+
+void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
+                int &n2) {
+  int64_t normalized_ndim = normalized_shape.size();
+
+  if (normalized_ndim < 1) {
+    std::stringstream ss;
+    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
+       << "containing at least one element, but got normalized_shape="
+       << normalized_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  auto input_shape = input.sizes();
+  auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
     }
+    ss << "], but got input of size" << input_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  compute_n1_n2(input, normalized_shape, n1, n2);
 }
 
-void check_args(
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta
-    )
-{
-  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
-  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
+void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
+                at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
+  check_args(input, normalized_shape, n1, n2);
+  check_args(normalized_shape, gamma, beta);
 }
+} // namespace
 
-void check_args(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    int& n1,
-    int& n2
-    )
-{
-  int64_t normalized_ndim = normalized_shape.size();
-
-  if (normalized_ndim < 1) {
-    std::stringstream ss;
-    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
-       << "containing at least one element, but got normalized_shape="
-       << normalized_shape;
-    throw std::runtime_error(ss.str());
-  }
-
-  auto input_shape = input.sizes();
-  auto input_ndim = input.dim();
-
-  if (input_ndim < normalized_ndim ||
-      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
-    std::stringstream ss;
-    ss << "Given normalized_shape=" << normalized_shape
-       << ", expected input with shape [*";
-    for (auto size : normalized_shape) {
-      ss << ", " << size;
-    }
-    ss << "], but got input of size" << input_shape;
-    throw std::runtime_error(ss.str());
-  }
-
-  compute_n1_n2(input,normalized_shape,n1,n2);
-}
-
-
-void check_args(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta,
-    int& n1,
-    int& n2
-    )
-{
-  check_args(input,normalized_shape,n1,n2);
-  check_args(normalized_shape,gamma,beta);
-}
-}
-
-void cuda_layer_norm(
-    at::Tensor* output,
-    at::Tensor* mean,
-    at::Tensor* invvar,
-    at::Tensor* input,
-    int n1,
-    int n2,
-    at::IntArrayRef normalized_shape,
-    at::Tensor* gamma,
-    at::Tensor* beta,
-    double epsilon);
+void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
+                     at::Tensor *input, int n1, int n2,
+                     at::IntArrayRef normalized_shape, at::Tensor *gamma,
+                     at::Tensor *beta, double epsilon);
 
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CONTIGUOUS(x)                                                    \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x)                                                         \
+  CHECK_CUDA(x);                                                               \
+  CHECK_CONTIGUOUS(x)
+
+std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
+                                          at::IntArrayRef normalized_shape,
+                                          at::Tensor gamma, at::Tensor beta,
+                                          double epsilon) {
 
-std::vector<at::Tensor> layer_norm_affine(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta,
-    double epsilon) {
-
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
 
-  at::Tensor output = at::empty_like(
-      input, gamma.options().dtype(gamma.scalar_type()));
-  at::Tensor mean = at::empty(
-      {n1}, input.options().dtype(at::ScalarType::Float));
+  at::Tensor output =
+      at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean =
+      at::empty({n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);
 
-  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
-      normalized_shape, &gamma, &beta, epsilon);
+  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
+                  &gamma, &beta, epsilon);
 
   return {output, mean, invvar};
-
 }
 
+void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
+                              at::Tensor *invvar, at::Tensor *input, int n1,
+                              int n2, at::IntArrayRef normalized_shape,
+                              at::Tensor *gamma, at::Tensor *beta,
+                              double epsilon, at::Tensor *grad_input,
+                              at::Tensor *grad_gamma, at::Tensor *grad_beta);
 
-void cuda_layer_norm_gradient(
-    at::Tensor* dout,
-    at::Tensor* mean,
-    at::Tensor* invvar,
-    at::Tensor* input,
-    int n1,
-    int n2,
-    at::IntArrayRef normalized_shape,
-    at::Tensor* gamma,
-    at::Tensor* beta,
-    double epsilon,
-    at::Tensor* grad_input,
-    at::Tensor* grad_gamma,
-    at::Tensor* grad_beta
-    );
-
-std::vector<at::Tensor> layer_norm_gradient_affine(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta,
-    double epsilon) {
+std::vector<at::Tensor>
+layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
+                           at::Tensor input, at::IntArrayRef normalized_shape,
+                           at::Tensor gamma, at::Tensor beta, double epsilon) {
 
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
@@ -169,17 +128,14 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
   at::Tensor grad_beta = at::empty_like(beta);
 
   cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
-      normalized_shape, &gamma, &beta, epsilon,
-      &grad_input, &grad_gamma, &grad_beta);
+                           normalized_shape, &gamma, &beta, epsilon,
+                           &grad_input, &grad_gamma, &grad_beta);
 
   return {grad_input, grad_gamma, grad_beta};
-
 }
 
-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_affine", &layer_norm_affine,
-    "LayerNorm forward (CUDA)");
+  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
   m.def("backward_affine", &layer_norm_gradient_affine,
-    "LayerNorm backward (CUDA)");
+        "LayerNorm backward (CUDA)");
 }
\ No newline at end of file
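
Usage note (not part of the patch): the commit above only reflows the C++ source, so the pybind11 module still exports forward_affine and backward_affine with the argument order of layer_norm_affine and layer_norm_gradient_affine. The sketch below shows one common way such a TORCH_EXTENSION is built and called from Python; the torch.utils.cpp_extension.load call, the companion .cu kernel file name, and the tensor shapes are illustrative assumptions rather than anything this commit specifies.

    import torch
    from torch.utils.cpp_extension import load

    # Hypothetical build step: the source paths are assumptions; only the binding
    # names and argument order come from the bindings in layer_norm_cuda.cpp.
    layer_norm_cuda = load(
        name="layer_norm_cuda",
        sources=[
            "colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp",
            "colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu",  # assumed kernel source
        ],
    )

    hidden = 1024
    x = torch.randn(8, hidden, device="cuda")
    gamma = torch.ones(hidden, device="cuda")
    beta = torch.zeros(hidden, device="cuda")

    # forward_affine -> (output, mean, invvar), mirroring layer_norm_affine.
    out, mean, invvar = layer_norm_cuda.forward_affine(x, [hidden], gamma, beta, 1e-5)

    # backward_affine reuses the saved mean/invvar to return the three gradients.
    dout = torch.ones_like(out)
    dx, dgamma, dbeta = layer_norm_cuda.backward_affine(
        dout, mean, invvar, x, [hidden], gamma, beta, 1e-5)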