[NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu code style (#667)

pull/673/head
encmps 2022-04-05 09:12:01 +08:00 committed by binmakeswell
parent e4bcff9b0f
commit 79ccfa4310
1 changed file with 102 additions and 140 deletions

colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu

@@ -1,4 +1,5 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
+// modified from
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -8,14 +9,13 @@
 #include <assert.h>
 
-#include "type_shim.h"
 #include "multi_tensor_apply.cuh"
+#include "type_shim.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
 
-typedef enum
-{
+typedef enum {
   ADAM_MODE_0 = 0, // L2 regularization mode
   ADAM_MODE_1 = 1  // Decoupled weight decay mode(AdamW)
 } adamMode_t;
 
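For reference, a minimal scalar sketch of one Adam step in each of the two modes (names ours, not from the commit; the decoupled-decay line in mode 1 follows the upstream apex kernel credited in the header comment, since that part of the function is unchanged and therefore not shown in this diff):

    #include <cmath>

    // One scalar Adam step (illustrative sketch; bc1/bc2 are the bias
    // corrections 1 - beta1^step and 1 - beta2^step, computed on the host).
    float adam_step(float p, float g, float &m, float &v, float lr,
                    float beta1, float beta2, float eps, float decay,
                    int mode, float bc1, float bc2) {
      if (mode == 0) g += decay * p;        // ADAM_MODE_0: L2 folds decay into the gradient
      m = beta1 * m + (1 - beta1) * g;      // first-moment EMA
      v = beta2 * v + (1 - beta2) * g * g;  // second-moment EMA
      float update = (m / bc1) / (std::sqrt(v / bc2) + eps);
      if (mode == 1) update += decay * p;   // ADAM_MODE_1: decoupled decay (AdamW)
      return p - lr * update;
    }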
@@ -23,21 +23,12 @@ typedef enum
 using MATH_T = float;
 
 template <typename T_g, typename T_p>
-struct AdamFunctor
-{
+struct AdamFunctor {
   __device__ __forceinline__ void operator()(
-      int chunk_size,
-      volatile int *noop_gmem,
-      TensorListMetadata<4> &tl,
-      const float beta1,
-      const float beta2,
-      const float beta1_correction,
-      const float beta2_correction,
-      const float epsilon,
-      const float lr,
-      adamMode_t mode,
-      const float decay)
-  {
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
+      const float beta1, const float beta2, const float beta1_correction,
+      const float beta2_correction, const float epsilon, const float lr,
+      adamMode_t mode, const float decay) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
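The block/chunk bookkeeping that precedes the loop in the next hunk is unchanged code and not shown in the diff. A sketch of what it does, following the upstream apex kernel (the exact pointer types for the optimizer state are elided from this diff, so T_p here is an assumption):

    int tensor_loc = tl.block_to_tensor[blockIdx.x]; // which tensor this block serves
    int chunk_idx = tl.block_to_chunk[blockIdx.x];   // which chunk of that tensor
    int n = tl.sizes[tensor_loc];                    // element count of that tensor

    // Depth-4 metadata: slot 0 = grads, 1 = params, 2 = exp_avg, 3 = exp_avg_sq.
    T_g *g = (T_g *)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
    T_p *p = (T_p *)tl.addresses[1][tensor_loc] + chunk_idx * chunk_size;
    T_p *m = (T_p *)tl.addresses[2][tensor_loc] + chunk_idx * chunk_size;
    T_p *v = (T_p *)tl.addresses[3][tensor_loc] + chunk_idx * chunk_size;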
@@ -65,27 +56,21 @@ struct AdamFunctor
     n -= chunk_idx * chunk_size;
 
     // see note in multi_tensor_scale_kernel.cu
-    for (int i_start = 0;
-         i_start < n && i_start < chunk_size;
-         i_start += blockDim.x * ILP)
-    {
+    for (int i_start = 0; i_start < n && i_start < chunk_size;
+         i_start += blockDim.x * ILP) {
       MATH_T r_g[ILP];
       MATH_T r_p[ILP];
       MATH_T r_m[ILP];
       MATH_T r_v[ILP];
 #pragma unroll
-      for (int ii = 0; ii < ILP; ii++)
-      {
+      for (int ii = 0; ii < ILP; ii++) {
         int i = i_start + threadIdx.x + ii * blockDim.x;
-        if (i < n && i < chunk_size)
-        {
+        if (i < n && i < chunk_size) {
           r_g[ii] = g[i];
           r_p[ii] = p[i];
           r_m[ii] = m[i];
           r_v[ii] = v[i];
-        }
-        else
-        {
+        } else {
           r_g[ii] = MATH_T(0);
           r_p[ii] = MATH_T(0);
           r_m[ii] = MATH_T(0);
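This loop stages ILP = 4 elements per thread into registers before any math runs: consecutive threads read consecutive addresses (coalesced), and out-of-range lanes are zero-filled so the unrolled body needs no divergent branches. Distilled to its generic shape (a sketch of the pattern, not code from this file):

    template <typename T, int ILP_>
    __device__ void ilp_pattern(T *x, int n, int chunk_size) {
      for (int i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * ILP_) {
        T r[ILP_];
    #pragma unroll
        for (int ii = 0; ii < ILP_; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          r[ii] = (i < n && i < chunk_size) ? x[i] : T(0); // zero-fill the tail
        }
        // ... compute on r[0..ILP_-1] entirely in registers ...
    #pragma unroll
        for (int ii = 0; ii < ILP_; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) x[i] = r[ii]; // masked write-back
        }
      }
    }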
@@ -93,10 +78,8 @@ struct AdamFunctor
         }
       }
 #pragma unroll
-      for (int ii = 0; ii < ILP; ii++)
-      {
-        if (mode == ADAM_MODE_0)
-        { // L2
+      for (int ii = 0; ii < ILP; ii++) {
+        if (mode == ADAM_MODE_0) { // L2
           r_g[ii] = r_g[ii] + (decay * r_p[ii]);
           r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
           r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
@@ -105,9 +88,7 @@ struct AdamFunctor
           MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
           MATH_T update = next_m_unbiased / denom;
           r_p[ii] = r_p[ii] - (lr * update);
-        }
-        else
-        { // weight decay
+        } else { // weight decay
           r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
           r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -118,11 +99,9 @@ struct AdamFunctor
         }
       }
 #pragma unroll
-      for (int ii = 0; ii < ILP; ii++)
-      {
+      for (int ii = 0; ii < ILP; ii++) {
         int i = i_start + threadIdx.x + ii * blockDim.x;
-        if (i < n && i < chunk_size)
-        {
+        if (i < n && i < chunk_size) {
           p[i] = r_p[ii];
           m[i] = r_m[ii];
           v[i] = r_v[ii];
@@ -132,46 +111,29 @@ struct AdamFunctor
     }
   }
 };
 
-void multi_tensor_adam_cuda(
-    int chunk_size,
-    at::Tensor noop_flag,
+void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
-    const float lr,
-    const float beta1,
-    const float beta2,
-    const float epsilon,
-    const int step,
-    const int mode,
+                            const float lr, const float beta1,
+                            const float beta2, const float epsilon,
+                            const int step, const int mode,
                             const int bias_correction,
-    const float weight_decay)
-{
+                            const float weight_decay) {
   using namespace at;
 
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
-  if (bias_correction == 1)
-  {
+  if (bias_correction == 1) {
     bias_correction1 = 1 - std::pow(beta1, step);
     bias_correction2 = 1 - std::pow(beta2, step);
   }
 
   DISPATCH_FLOAT_AND_HALF_FOR_G_P(
-      tensor_lists[0][0].scalar_type(),
-      tensor_lists[1][0].scalar_type(), 0, "adam",
-      multi_tensor_apply<4>(
-          BLOCK_SIZE,
-          chunk_size,
-          noop_flag,
-          tensor_lists,
-          AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
-          beta1,
-          beta2,
-          bias_correction1,
-          bias_correction2,
-          epsilon,
-          lr,
-          (adamMode_t)mode,
-          weight_decay);)
+      tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
+      "adam",
+      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                            AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
+                            beta2, bias_correction1, bias_correction2, epsilon,
+                            lr, (adamMode_t)mode, weight_decay);)
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
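For context, a hypothetical C++ caller of this entry point, not part of the commit: the list order {grads, params, exp_avg, exp_avg_sq} is inferred from the depth-4 metadata and from the dispatch reading the g type from tensor_lists[0] and the p type from tensor_lists[1]; the function name, hyperparameters, and chunk size below are illustrative (2048 * 32 is the conventional apex value).

    #include <torch/extension.h>

    // Hypothetical driver sketch; assumes all tensors live on one CUDA device.
    void fused_adam_step(std::vector<at::Tensor> grads,
                         std::vector<at::Tensor> params,
                         std::vector<at::Tensor> exp_avg,
                         std::vector<at::Tensor> exp_avg_sq,
                         float lr, int step) {
      // noop_flag lets the kernels skip work after an inf/nan is detected.
      auto noop_flag = at::zeros({1}, params[0].options().dtype(at::kInt));
      std::vector<std::vector<at::Tensor>> lists = {grads, params,
                                                    exp_avg, exp_avg_sq};
      multi_tensor_adam_cuda(/*chunk_size=*/2048 * 32, noop_flag, lists,
                             lr, /*beta1=*/0.9f, /*beta2=*/0.999f,
                             /*epsilon=*/1e-8f, step, /*mode=*/ADAM_MODE_1,
                             /*bias_correction=*/1, /*weight_decay=*/0.01f);
    }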