mirror of https://github.com/hpcaitech/ColossalAI
[format]colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp (#566)
parent 5835631218
commit 0f1da44e5e
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp

@@ -2,17 +2,14 @@
  * https://github.com/NVIDIA/apex
  * with minor changes. */

+#include "compat.h"
+
+#include <cassert>
 #include <torch/extension.h>
 #include <vector>
-#include <cassert>
-#include "compat.h"

 namespace {

-void compute_n1_n2(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    int& n1,
-    int& n2)
-{
+void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
+                   int &n2) {
   int idiff = input.ndimension() - normalized_shape.size();
   n2 = 1;
@@ -26,23 +23,14 @@ void compute_n1_n2(
   }
 }

-void check_args(
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta
-    )
-{
+void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
+                at::Tensor beta) {
   TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
   TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
 }

-void check_args(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    int& n1,
-    int& n2
-    )
-{
+void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
+                int &n2) {
   int64_t normalized_ndim = normalized_shape.size();

   if (normalized_ndim < 1) {
@@ -57,7 +45,8 @@ void check_args(
   auto input_ndim = input.dim();

   if (input_ndim < normalized_ndim ||
-      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
     std::stringstream ss;
     ss << "Given normalized_shape=" << normalized_shape
        << ", expected input with shape [*";
@@ -71,42 +60,28 @@ void check_args(
   compute_n1_n2(input, normalized_shape, n1, n2);
 }

-
-void check_args(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta,
-    int& n1,
-    int& n2
-    )
-{
+void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
+                at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
   check_args(input, normalized_shape, n1, n2);
   check_args(normalized_shape, gamma, beta);
 }
-}
+} // namespace

-void cuda_layer_norm(
-    at::Tensor* output,
-    at::Tensor* mean,
-    at::Tensor* invvar,
-    at::Tensor* input,
-    int n1,
-    int n2,
-    at::IntArrayRef normalized_shape,
-    at::Tensor* gamma,
-    at::Tensor* beta,
-    double epsilon);
+void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
+                     at::Tensor *input, int n1, int n2,
+                     at::IntArrayRef normalized_shape, at::Tensor *gamma,
+                     at::Tensor *beta, double epsilon);

 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> layer_norm_affine(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta,
-    double epsilon) {
+std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
+                                          at::IntArrayRef normalized_shape,
+                                          at::Tensor gamma, at::Tensor beta,
+                                          double epsilon) {

   CHECK_INPUT(input);
@@ -115,45 +90,29 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);

-  at::Tensor output = at::empty_like(
-      input, gamma.options().dtype(gamma.scalar_type()));
-  at::Tensor mean = at::empty(
-      {n1}, input.options().dtype(at::ScalarType::Float));
+  at::Tensor output =
+      at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean =
+      at::empty({n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);

-  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
-      normalized_shape, &gamma, &beta, epsilon);
+  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
+                  &gamma, &beta, epsilon);

   return {output, mean, invvar};
-
 }

-void cuda_layer_norm_gradient(
-    at::Tensor* dout,
-    at::Tensor* mean,
-    at::Tensor* invvar,
-    at::Tensor* input,
-    int n1,
-    int n2,
-    at::IntArrayRef normalized_shape,
-    at::Tensor* gamma,
-    at::Tensor* beta,
-    double epsilon,
-    at::Tensor* grad_input,
-    at::Tensor* grad_gamma,
-    at::Tensor* grad_beta
-    );
+void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
+                              at::Tensor *invvar, at::Tensor *input, int n1,
+                              int n2, at::IntArrayRef normalized_shape,
+                              at::Tensor *gamma, at::Tensor *beta,
+                              double epsilon, at::Tensor *grad_input,
+                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

-std::vector<at::Tensor> layer_norm_gradient_affine(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    at::Tensor gamma,
-    at::Tensor beta,
-    double epsilon) {
+std::vector<at::Tensor>
+layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
+                           at::Tensor input, at::IntArrayRef normalized_shape,
+                           at::Tensor gamma, at::Tensor beta, double epsilon) {

   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
@@ -173,13 +132,10 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
                           &grad_input, &grad_gamma, &grad_beta);

   return {grad_input, grad_gamma, grad_beta};
-
 }

-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_affine", &layer_norm_affine,
-      "LayerNorm forward (CUDA)");
+  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
   m.def("backward_affine", &layer_norm_gradient_affine,
         "LayerNorm backward (CUDA)");
 }
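A note for readers of the diff: only the first two lines of compute_n1_n2's body fall inside the first hunk. For context, here is a sketch of the whole helper, following the upstream NVIDIA/apex source this file credits; the loop bodies below are not part of this commit and are shown only to explain what n1 and n2 mean. The helper flattens the input into an n1 x n2 problem, with n2 the product of the trailing normalized_shape dimensions and n1 the product of all leading (batch) dimensions.

// Sketch of the full helper; assumes this file's own includes
// (<cassert>, <torch/extension.h>). Loop bodies follow upstream apex
// and are reproduced here for context only.
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                   int &n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  // The trailing dims must match normalized_shape; their product is n2.
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  // Every leading (batch) dim collapses into n1.
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

So for an input of shape [8, 16, 1024] normalized over [1024], the kernel sees n1 = 128 rows of n2 = 1024 elements each.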
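One thing the reflowed macros make easier to see: CHECK_INPUT still expands to two separate statements, exactly as before the commit. That is fine in the straight-line calls this file makes, but it is the classic multi-statement-macro pitfall, since under an unbraced if only the first statement is guarded. A hypothetical caller (not from this file) illustrating the hazard:

// Hypothetical, for illustration only; not part of the commit.
void validate(at::Tensor t, bool strict) {
  if (strict)
    CHECK_INPUT(t);
  // Preprocesses to:
  //   if (strict) CHECK_CUDA(t);
  //   CHECK_CONTIGUOUS(t);
  // so the contiguity check runs even when strict is false.
}

The usual defensive rewrite is to wrap the macro body in do { ... } while (0); this commit deliberately keeps the original semantics and only reflows the line breaks.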
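For orientation, the module's exported API is unchanged by the reformat: forward_affine returns {output, mean, invvar} and backward_affine returns {grad_input, grad_gamma, grad_beta}. A hypothetical C++ driver for the forward wrapper, assuming the extension's .cpp and .cu sources are linked into the build and a CUDA device is available (the shapes are illustrative):

// Hypothetical driver; the declaration is copied from the file above.
#include <torch/torch.h>
#include <vector>

std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon);

int main() {
  // compute_n1_n2 will see n1 = 8, n2 = 1024 for these shapes.
  at::Tensor input = torch::randn({8, 1024}, torch::kCUDA);
  at::Tensor gamma = torch::ones({1024}, torch::kCUDA);
  at::Tensor beta = torch::zeros({1024}, torch::kCUDA);
  std::vector<at::Tensor> outs =
      layer_norm_affine(input, {1024}, gamma, beta, /*epsilon=*/1e-5);
  at::Tensor output = outs[0], mean = outs[1], invvar = outs[2];
  return 0;
}

From Python the same entry points are reached through the bound names, e.g. module.forward_affine(input, normalized_shape, gamma, beta, eps) after building the extension with torch.utils.cpp_extension.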