diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
index 15ac20914..54c422019 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
@@ -15,7 +15,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
-template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
 
@@ -28,24 +29,25 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
 }
 
 typedef enum {
-  MOMENT_MODE_0 = 0, // L2 regularization mode
-  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
+  MOMENT_MODE_0 = 0,  // L2 regularization mode
+  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
 } adamMode_t;
 
-std::tuple<at::Tensor, at::Tensor>
-multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
-                         std::vector<std::vector<at::Tensor>> tensor_lists,
-                         at::optional<bool> per_tensor_python);
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists,
+    at::optional<bool> per_tensor_python);
 
 using MATH_T = float;
 
-template <typename T> struct LAMBStage1Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
-             const float beta1, const float beta2, const float beta3,
-             const float beta1_correction, const float beta2_correction,
-             const float epsilon, adamMode_t mode, const float decay,
-             const float *global_grad_norm, const float max_global_grad_norm) {
+template <typename T>
+struct LAMBStage1Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
+      const float beta1, const float beta2, const float beta3,
+      const float beta1_correction, const float beta2_correction,
+      const float epsilon, adamMode_t mode, const float decay,
+      const float *global_grad_norm, const float max_global_grad_norm) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
            i_start += blockDim.x) {
         // load
         load_store(l_g, g, 0, i_start);
-        if (decay != 0)
-          load_store(l_p, p, 0, i_start);
+        if (decay != 0) load_store(l_p, p, 0, i_start);
         load_store(l_m, m, 0, i_start);
         load_store(l_v, v, 0, i_start);
         // unpack
@@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {
 
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template <typename T> struct LAMBStage2Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
-             const float *per_tensor_param_norm,
-             const float *per_tensor_update_norm, const float learning_rate,
-             const float decay, bool use_nvlamb) {
+template <typename T>
+struct LAMBStage2Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
+      const float *per_tensor_param_norm, const float *per_tensor_update_norm,
+      const float learning_rate, const float decay, bool use_nvlamb) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
 
   // Handle grad averaging mode
   float beta3 = 1.0f;
-  if (grad_averaging == 1)
-    beta3 = 1 - beta1;
+  if (grad_averaging == 1) beta3 = 1 - beta1;
 
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                  tensor_lists.begin() + 1);
@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
       tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
-                            beta3, // 1-beta1 or 1 depends on averaging mode
+                            beta3,  // 1-beta1 or 1 depends on averaging mode
                             bias_correction1, bias_correction2, epsilon,
                             (adamMode_t)mode, weight_decay,
                             global_grad_norm.DATA_PTR<float>(), max_grad_norm);)