[format]colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp (#566)

pull/673/head
Jie Zhu 2022-03-31 15:01:51 +08:00 committed by binmakeswell
parent 5835631218
commit 0f1da44e5e
1 changed file with 89 additions and 133 deletions


@@ -2,158 +2,117 @@
 * https://github.com/NVIDIA/apex
 * with minor changes. */
#include "compat.h"
#include <cassert>
#include <torch/extension.h>
#include <vector>

namespace {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                   int &n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
                at::Tensor beta) {
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                int &n2) {
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
                at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}
} // namespace

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *input, int n1, int n2,
                     at::IntArrayRef normalized_shape, at::Tensor *gamma,
                     at::Tensor *beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output =
      at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean =
      at::empty({n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
                  &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                              at::Tensor *invvar, at::Tensor *input, int n1,
                              int n2, at::IntArrayRef normalized_shape,
                              at::Tensor *gamma, at::Tensor *beta,
                              double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

std::vector<at::Tensor>
layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
                           at::Tensor input, at::IntArrayRef normalized_shape,
                           at::Tensor gamma, at::Tensor beta, double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);

@@ -169,17 +128,14 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
        "LayerNorm backward (CUDA)");
}
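
For context, the bindings above expose forward_affine and backward_affine to Python. Below is a minimal usage sketch, not part of the commit, showing how a JIT-built copy of this extension could be driven. The module name, the companion kernel file name, and the tensor shapes are assumptions for illustration; ColossalAI's own build wires these sources up through its extension loader.

    import torch
    from torch.utils.cpp_extension import load

    # Hypothetical ad-hoc JIT build; a companion .cu file is assumed to provide
    # cuda_layer_norm and cuda_layer_norm_gradient declared in this .cpp file.
    layer_norm_cuda = load(
        name="layer_norm_cuda",
        sources=["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"],
    )

    hidden = 1024
    x = torch.randn(8, 128, hidden, device="cuda")
    gamma = torch.ones(hidden, device="cuda")
    beta = torch.zeros(hidden, device="cuda")
    eps = 1e-5

    # forward_affine returns [output, mean, invvar];
    # here n1 = 8 * 128 (leading dims) and n2 = 1024 (normalized dims).
    out, mean, invvar = layer_norm_cuda.forward_affine(x, (hidden,), gamma, beta, eps)

    # backward_affine consumes the saved mean/invvar and returns the three gradients.
    dout = torch.randn_like(out)
    dx, dgamma, dbeta = layer_norm_cuda.backward_affine(
        dout, mean, invvar, x, (hidden,), gamma, beta, eps)

In practice these entry points are typically wrapped by an autograd Function (as in apex's FusedLayerNorm), so user code does not normally call them directly.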