fix format (#568)

pull/673/head
Xu Kai 2022-03-31 15:13:01 +08:00 committed by binmakeswell
parent 9420d3ae31
commit 2a915a8b62
2 changed files with 41 additions and 65 deletions

View File

@@ -1,5 +1,3 @@
 from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention
 
-__all__ = [
-    "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
-]
+__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
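
The flattened __all__ leaves the public surface unchanged. A quick sanity check, assuming this __init__ belongs to the colossalai.kernel package (the file name is dropped in this scraped view, so the path is an assumption):

# Hypothetical check; the package path colossalai.kernel is assumed,
# it is not shown in this commit view.
import colossalai.kernel as kernel

# The reformatted __all__ still exports exactly these three names.
assert sorted(kernel.__all__) == sorted(
    ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"])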

View File

@@ -1,63 +1,41 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
+// modified from
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
 #include <torch/extension.h>
 
-void multi_tensor_scale_cuda(
-  int chunk_size,
-  at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists,
-  float scale);
+void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
+                             std::vector<std::vector<at::Tensor>> tensor_lists,
+                             float scale);
 
-void multi_tensor_sgd_cuda(
-  int chunk_size,
-  at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists,
-  float wd,
-  float momentum,
-  float dampening,
-  float lr,
-  bool nesterov,
-  bool first_run,
-  bool wd_after_momentum,
-  float scale);
+void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
+                           std::vector<std::vector<at::Tensor>> tensor_lists,
+                           float wd, float momentum, float dampening, float lr,
+                           bool nesterov, bool first_run,
+                           bool wd_after_momentum, float scale);
 
-void multi_tensor_adam_cuda(
-  int chunk_size,
-  at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists,
-  const float lr,
-  const float beta1,
-  const float beta2,
-  const float epsilon,
-  const int step,
-  const int mode,
-  const int bias_correction,
-  const float weight_decay);
+void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
+                            std::vector<std::vector<at::Tensor>> tensor_lists,
+                            const float lr, const float beta1,
+                            const float beta2, const float epsilon,
+                            const int step, const int mode,
+                            const int bias_correction,
+                            const float weight_decay);
 
-void multi_tensor_lamb_cuda(
-  int chunk_size,
-  at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists,
-  const float lr,
-  const float beta1,
-  const float beta2,
-  const float epsilon,
-  const int step,
-  const int bias_correction,
-  const float weight_decay,
-  const int grad_averaging,
-  const int mode,
-  at::Tensor global_grad_norm,
-  const float max_grad_norm,
-  at::optional<bool> use_nvlamb_python);
+void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
+                            std::vector<std::vector<at::Tensor>> tensor_lists,
+                            const float lr, const float beta1,
+                            const float beta2, const float epsilon,
+                            const int step, const int bias_correction,
+                            const float weight_decay, const int grad_averaging,
+                            const int mode, at::Tensor global_grad_norm,
+                            const float max_grad_norm,
+                            at::optional<bool> use_nvlamb_python);
 
-std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
-  int chunk_size,
-  at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists,
-  at::optional<bool> per_tensor_python);
+std::tuple<at::Tensor, at::Tensor>
+multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
+                         std::vector<std::vector<at::Tensor>> tensor_lists,
+                         at::optional<bool> per_tensor_python);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
   m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,