[NFC] Polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code style. (#937)

pull/997/head
BoxiangW 2022-05-13 15:07:19 +08:00 committed by binmakeswell
parent 58580b50fe
commit 872aa413c2
1 changed file with 25 additions and 25 deletions

@@ -15,7 +15,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
-template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
 
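The one-line-to-two-line split above is pure formatting; the check itself is what gates the vectorized fast path: a pointer may only be read ILP elements at a time when it is aligned to ILP * sizeof(T) bytes. For context, the load_store helper named in the next hunk performs that wide copy. A minimal sketch, modeled on the NVIDIA Apex kernel this file derives from (illustrative, not verbatim from this file):

// Copies ILP consecutive elements of T with one wide, aligned access by
// viewing both arrays through an ILP * sizeof(T)-byte storage type.
// Requires <type_traits> for std::aligned_storage.
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
                                           int src_offset) {
  typedef
      typename std::aligned_storage<ILP * sizeof(T), ILP * sizeof(T)>::type LT;
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}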
@@ -28,24 +29,25 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
 }
 
 typedef enum {
-  MOMENT_MODE_0 = 0, // L2 regularization mode
-  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
+  MOMENT_MODE_0 = 0,  // L2 regularization mode
+  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
 } adamMode_t;
 
-std::tuple<at::Tensor, at::Tensor>
-multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
-                         std::vector<std::vector<at::Tensor>> tensor_lists,
-                         at::optional<bool> per_tensor_python);
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists,
+    at::optional<bool> per_tensor_python);
 
 using MATH_T = float;
 
-template <typename T> struct LAMBStage1Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
-             const float beta1, const float beta2, const float beta3,
-             const float beta1_correction, const float beta2_correction,
-             const float epsilon, adamMode_t mode, const float decay,
-             const float *global_grad_norm, const float max_global_grad_norm) {
+template <typename T>
+struct LAMBStage1Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
+      const float beta1, const float beta2, const float beta3,
+      const float beta1_correction, const float beta2_correction,
+      const float epsilon, adamMode_t mode, const float decay,
+      const float *global_grad_norm, const float max_global_grad_norm) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
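The functor body below these comments is untouched by the commit; it implements stage 1 of LAMB: clip gradients by the global norm, update the Adam moments, and emit a per-element update. A scalar sketch of that math, following the Apex original (lamb_stage1_update is a hypothetical helper, not a function in this file):

#include <cmath>

// Hypothetical scalar model of LAMBStage1Functor's per-element work.
float lamb_stage1_update(float g, float p, float &m, float &v, float beta1,
                         float beta2, float beta3, float beta1_correction,
                         float beta2_correction, float epsilon, int mode,
                         float decay, float global_grad_norm,
                         float max_global_grad_norm) {
  // Scale every gradient down when the global norm exceeds the maximum.
  float clip = global_grad_norm > max_global_grad_norm
                   ? global_grad_norm / max_global_grad_norm
                   : 1.0f;
  float scaled_grad = g / clip;
  if (mode == 0) scaled_grad += decay * p;  // MOMENT_MODE_0: L2 into the grad
  m = beta1 * m + beta3 * scaled_grad;      // first moment (beta3 = 1-beta1 or 1)
  v = beta2 * v + (1 - beta2) * scaled_grad * scaled_grad;  // second moment
  float update =
      (m / beta1_correction) / (sqrtf(v / beta2_correction) + epsilon);
  if (mode == 1) update += decay * p;  // MOMENT_MODE_1: decoupled decay
  return update;  // stage 2 rescales this by the layer-wise trust ratio
}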
@@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
            i_start += blockDim.x) {
         // load
         load_store(l_g, g, 0, i_start);
-        if (decay != 0)
-          load_store(l_p, p, 0, i_start);
+        if (decay != 0) load_store(l_p, p, 0, i_start);
         load_store(l_m, m, 0, i_start);
         load_store(l_v, v, 0, i_start);
         // unpack
@@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {
 
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template <typename T> struct LAMBStage2Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
-             const float *per_tensor_param_norm,
-             const float *per_tensor_update_norm, const float learning_rate,
-             const float decay, bool use_nvlamb) {
+template <typename T>
+struct LAMBStage2Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
+      const float *per_tensor_param_norm, const float *per_tensor_update_norm,
+      const float learning_rate, const float decay, bool use_nvlamb) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
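The stage-2 body, also unchanged here, rescales the stage-1 update by a per-tensor trust ratio before applying it. A sketch of that scaling, following the Apex original (trust_ratio is a hypothetical helper name):

// Hypothetical model of the per-tensor scaling in LAMBStage2Functor.
// NVLAMB applies the adaptive rate to every tensor; plain LAMB applies
// it only to tensors with non-zero weight decay.
float trust_ratio(float learning_rate, float param_norm, float update_norm,
                  float decay, bool use_nvlamb) {
  if (!use_nvlamb && decay == 0.0f) return learning_rate;
  if (param_norm == 0.0f || update_norm == 0.0f) return learning_rate;
  return learning_rate * (param_norm / update_norm);
}
// applied as: p = p - trust_ratio(...) * update;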
@@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
 
   // Handle grad averaging mode
   float beta3 = 1.0f;
-  if (grad_averaging == 1)
-    beta3 = 1 - beta1;
+  if (grad_averaging == 1) beta3 = 1 - beta1;
 
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                  tensor_lists.begin() + 1);
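Folding the branch onto one line does not change behavior: beta3 only selects between the two first-moment formulas,

// grad_averaging == 1:  m = beta1 * m + (1 - beta1) * g   (EMA of gradients)
// grad_averaging == 0:  m = beta1 * m + g                 (plain accumulation)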
@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
       tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
-                            beta3, // 1-beta1 or 1 depends on averaging mode
+                            beta3,  // 1-beta1 or 1 depends on averaging mode
                             bias_correction1, bias_correction2, epsilon,
                             (adamMode_t)mode, weight_decay,
                             global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
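The dispatch macro wrapping this launch binds scalar_t_0 to the runtime dtype before instantiating the functor. A rough sketch of what such a macro expands to, assuming it follows the usual AT_DISPATCH pattern (illustrative only):

// Illustrative expansion; the real macro also centralizes error handling.
switch (tensor_lists[0][0].scalar_type()) {
  case at::ScalarType::Float: {
    using scalar_t_0 = float;
    multi_tensor_apply<4>(/* arguments as written above */);
    break;
  }
  case at::ScalarType::Half: {
    using scalar_t_0 = at::Half;
    multi_tensor_apply<4>(/* arguments as written above */);
    break;
  }
  default:
    AT_ERROR("lamb_stage_1 not implemented for this dtype");
}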