[NFC] Polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code style. (#937)

pull/997/head
BoxiangW 2022-05-13 15:07:19 +08:00 committed by binmakeswell
parent 58580b50fe
commit 872aa413c2
1 changed file with 25 additions and 25 deletions


@@ -15,7 +15,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
-template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
 
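For context on the hunk above: is_aligned guards the vectorized fast path, where a pointer aligned to ILP * sizeof(T) bytes lets one thread move ILP elements per wide transaction. A minimal sketch of that pattern, assuming ILP == 4 and T == float so the wide type is float4; the copy_chunk helper is illustrative, not this file's load_store:

#include <cstdint>

#define ILP 4

template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

// Illustrative helper (hypothetical, not part of multi_tensor_lamb.cu):
// one 16-byte float4 transaction per ILP floats on the aligned path,
// scalar element-wise copies otherwise.
__device__ void copy_chunk(float *dst, const float *src, int n) {
  if (is_aligned(dst) && is_aligned(src) && n % ILP == 0) {
    for (int i = threadIdx.x; i * ILP < n; i += blockDim.x)
      reinterpret_cast<float4 *>(dst)[i] =
          reinterpret_cast<const float4 *>(src)[i];
  } else {
    for (int i = threadIdx.x; i < n; i += blockDim.x) dst[i] = src[i];
  }
}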
@@ -32,16 +33,17 @@ typedef enum {
   MOMENT_MODE_1 = 1 // Decoupled weight decay mode
 } adamMode_t;
 
-std::tuple<at::Tensor, at::Tensor>
-multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
+    int chunk_size, at::Tensor noop_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists,
     at::optional<bool> per_tensor_python);
 
 using MATH_T = float;
 
-template <typename T> struct LAMBStage1Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
+template <typename T>
+struct LAMBStage1Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
       const float beta1, const float beta2, const float beta3,
       const float beta1_correction, const float beta2_correction,
       const float epsilon, adamMode_t mode, const float decay,
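A note on the declaration polished above: multi_tensor_l2norm_cuda supplies the per-tensor L2 norms that LAMB's stage 2 consumes. A hedged host-side sketch of the call shape, assuming the upstream Apex convention that the tuple holds (overall norm, per-tensor norms); the wrapper name and that convention are assumptions, not quoted from this repo:

#include <torch/extension.h>

#include <tuple>
#include <vector>

// Declaration as in the diff above.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Hypothetical wrapper: returns one L2 norm per input tensor, assuming
// tuple element 1 holds the per-tensor norms when per_tensor is requested.
at::Tensor per_tensor_l2norms(int chunk_size, at::Tensor noop_flag,
                              std::vector<at::Tensor> params) {
  auto norms = multi_tensor_l2norm_cuda(chunk_size, noop_flag, {params},
                                        /*per_tensor_python=*/true);
  return std::get<1>(norms);
}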
@@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
          i_start += blockDim.x) {
       // load
       load_store(l_g, g, 0, i_start);
-      if (decay != 0)
-        load_store(l_p, p, 0, i_start);
+      if (decay != 0) load_store(l_p, p, 0, i_start);
       load_store(l_m, m, 0, i_start);
       load_store(l_v, v, 0, i_start);
       // unpack
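After the loads above, each thread walks its ILP-wide registers element by element. A scalar sketch of the standard LAMB stage-1 arithmetic in decoupled weight-decay mode (MOMENT_MODE_1); the variable names are mine, and this mirrors the well-known formulation rather than quoting the kernel verbatim:

// One element of the stage-1 'update' (sketch, not this kernel's code).
__device__ __forceinline__ float lamb_stage1(float &m, float &v, float g,
                                             float p, float beta1, float beta2,
                                             float beta3,
                                             float beta1_correction,
                                             float beta2_correction,
                                             float epsilon, float decay) {
  m = beta1 * m + beta3 * g;               // first moment (EMA of g)
  v = beta2 * v + (1.0f - beta2) * g * g;  // second moment (EMA of g^2)
  float m_hat = m / beta1_correction;      // bias correction
  float v_hat = v / beta2_correction;
  return m_hat / (sqrtf(v_hat) + epsilon) + decay * p;  // the 'update' value
}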
@@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {
 
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template <typename T> struct LAMBStage2Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
-             const float *per_tensor_param_norm,
-             const float *per_tensor_update_norm, const float learning_rate,
-             const float decay, bool use_nvlamb) {
+template <typename T>
+struct LAMBStage2Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
+      const float *per_tensor_param_norm, const float *per_tensor_update_norm,
+      const float learning_rate, const float decay, bool use_nvlamb) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
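Stage 2 then rescales each tensor's update by a trust ratio built from the two per-tensor norms; with use_nvlamb the ratio is applied unconditionally, otherwise typically only where weight decay is active. A sketch under those assumptions (the helper name is hypothetical):

// Per-tensor learning-rate scaling (sketch). Falls back to the plain
// learning rate when either norm is zero.
__device__ __forceinline__ float trust_ratio(float learning_rate,
                                             float param_norm,
                                             float update_norm, float decay,
                                             bool use_nvlamb) {
  if ((use_nvlamb || decay != 0.0f) && param_norm != 0.0f &&
      update_norm != 0.0f)
    return learning_rate * (param_norm / update_norm);
  return learning_rate;
}

// Applied per element as: p -= trust_ratio(...) * update;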
@@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
 
   // Handle grad averaging mode
   float beta3 = 1.0f;
-  if (grad_averaging == 1)
-    beta3 = 1 - beta1;
+  if (grad_averaging == 1) beta3 = 1 - beta1;
 
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                  tensor_lists.begin() + 1);
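One line of context on beta3: with grad_averaging enabled it becomes 1 - beta1, so the first-moment update m = beta1 * m + beta3 * g is a true exponential moving average of the gradient; disabled, beta3 stays 1 and gradients enter at full weight. A trivial sketch of that selection (hypothetical helper, mirroring the host-side branch above):

// Feeds m = beta1 * m + beta3 * g in stage 1.
inline float grad_averaging_beta3(int grad_averaging, float beta1) {
  return (grad_averaging == 1) ? (1.0f - beta1) : 1.0f;
}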