mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code stype (#628)
parent
0a96338b13
commit
c1bed0d998
|
@ -1,4 +1,5 @@
|
||||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
|
// modified from
|
||||||
|
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/AccumulateType.h>
|
#include <ATen/AccumulateType.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
@ -8,57 +9,43 @@
|
||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "type_shim.h"
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
|
#include "type_shim.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
|
||||||
template <typename T>
|
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
|
||||||
__device__ __forceinline__ bool is_aligned(T *p)
|
|
||||||
{
|
|
||||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset)
|
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||||
{
|
int src_offset) {
|
||||||
typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
typedef
|
||||||
|
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef enum
|
typedef enum {
|
||||||
{
|
|
||||||
MOMENT_MODE_0 = 0, // L2 regularization mode
|
MOMENT_MODE_0 = 0, // L2 regularization mode
|
||||||
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
||||||
} adamMode_t;
|
} adamMode_t;
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
int chunk_size,
|
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
|
||||||
at::Tensor noop_flag,
|
|
||||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||||
at::optional<bool> per_tensor_python);
|
at::optional<bool> per_tensor_python);
|
||||||
|
|
||||||
using MATH_T = float;
|
using MATH_T = float;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T> struct LAMBStage1Functor {
|
||||||
struct LAMBStage1Functor
|
__device__ __forceinline__ void
|
||||||
{
|
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||||
__device__ __forceinline__ void operator()(
|
const float beta1, const float beta2, const float beta3,
|
||||||
int chunk_size,
|
const float beta1_correction, const float beta2_correction,
|
||||||
volatile int *noop_gmem,
|
const float epsilon, adamMode_t mode, const float decay,
|
||||||
TensorListMetadata<4> &tl,
|
const float *global_grad_norm, const float max_global_grad_norm) {
|
||||||
const float beta1,
|
|
||||||
const float beta2,
|
|
||||||
const float beta3,
|
|
||||||
const float beta1_correction,
|
|
||||||
const float beta2_correction,
|
|
||||||
const float epsilon,
|
|
||||||
adamMode_t mode,
|
|
||||||
const float decay,
|
|
||||||
const float *global_grad_norm,
|
|
||||||
const float max_global_grad_norm)
|
|
||||||
{
|
|
||||||
// I'd like this kernel to propagate infs/nans.
|
// I'd like this kernel to propagate infs/nans.
|
||||||
// if(*noop_gmem == 1)
|
// if(*noop_gmem == 1)
|
||||||
// return;
|
// return;
|
||||||
|
@ -67,7 +54,10 @@ struct LAMBStage1Functor
|
||||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||||
int n = tl.sizes[tensor_loc];
|
int n = tl.sizes[tensor_loc];
|
||||||
|
|
||||||
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
|
float clipped_global_grad_norm =
|
||||||
|
(*global_grad_norm) > max_global_grad_norm
|
||||||
|
? (*global_grad_norm) / max_global_grad_norm
|
||||||
|
: 1.0f;
|
||||||
|
|
||||||
T *g = (T *)tl.addresses[0][tensor_loc];
|
T *g = (T *)tl.addresses[0][tensor_loc];
|
||||||
g += chunk_idx * chunk_size;
|
g += chunk_idx * chunk_size;
|
||||||
|
@ -88,19 +78,15 @@ struct LAMBStage1Functor
|
||||||
MATH_T r_m[ILP];
|
MATH_T r_m[ILP];
|
||||||
MATH_T r_v[ILP];
|
MATH_T r_v[ILP];
|
||||||
// to make things simple, we put aligned case in a different code path
|
// to make things simple, we put aligned case in a different code path
|
||||||
if (n % ILP == 0 &&
|
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) &&
|
||||||
chunk_size % ILP == 0 &&
|
is_aligned(p) && is_aligned(m) && is_aligned(v)) {
|
||||||
is_aligned(g) &&
|
|
||||||
is_aligned(p) &&
|
|
||||||
is_aligned(m) &&
|
|
||||||
is_aligned(v))
|
|
||||||
{
|
|
||||||
T l_g[ILP];
|
T l_g[ILP];
|
||||||
T l_p[ILP];
|
T l_p[ILP];
|
||||||
T l_m[ILP];
|
T l_m[ILP];
|
||||||
T l_v[ILP];
|
T l_v[ILP];
|
||||||
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
|
for (int i_start = threadIdx.x;
|
||||||
{
|
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||||
|
i_start += blockDim.x) {
|
||||||
// load
|
// load
|
||||||
load_store(l_g, g, 0, i_start);
|
load_store(l_g, g, 0, i_start);
|
||||||
if (decay != 0)
|
if (decay != 0)
|
||||||
|
@ -109,25 +95,19 @@ struct LAMBStage1Functor
|
||||||
load_store(l_v, v, 0, i_start);
|
load_store(l_v, v, 0, i_start);
|
||||||
// unpack
|
// unpack
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
r_g[ii] = l_g[ii];
|
r_g[ii] = l_g[ii];
|
||||||
if (decay == 0)
|
if (decay == 0) {
|
||||||
{
|
|
||||||
r_p[ii] = MATH_T(0);
|
r_p[ii] = MATH_T(0);
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
r_p[ii] = l_p[ii];
|
r_p[ii] = l_p[ii];
|
||||||
}
|
}
|
||||||
r_m[ii] = l_m[ii];
|
r_m[ii] = l_m[ii];
|
||||||
r_v[ii] = l_v[ii];
|
r_v[ii] = l_v[ii];
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
if (mode == MOMENT_MODE_0) {
|
||||||
if (mode == MOMENT_MODE_0)
|
|
||||||
{
|
|
||||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||||
// L2 on scaled grad
|
// L2 on scaled grad
|
||||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||||
|
@ -137,9 +117,7 @@ struct LAMBStage1Functor
|
||||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||||
r_p[ii] = next_m_unbiased / denom;
|
r_p[ii] = next_m_unbiased / denom;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||||
|
@ -150,8 +128,7 @@ struct LAMBStage1Functor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
l_p[ii] = r_p[ii];
|
l_p[ii] = r_p[ii];
|
||||||
l_m[ii] = r_m[ii];
|
l_m[ii] = r_m[ii];
|
||||||
l_v[ii] = r_v[ii];
|
l_v[ii] = r_v[ii];
|
||||||
|
@ -161,39 +138,28 @@ struct LAMBStage1Functor
|
||||||
load_store(m, l_m, i_start, 0);
|
load_store(m, l_m, i_start, 0);
|
||||||
load_store(v, l_v, i_start, 0);
|
load_store(v, l_v, i_start, 0);
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
// see note in multi_tensor_scale_kernel.cu
|
// see note in multi_tensor_scale_kernel.cu
|
||||||
for (int i_start = 0;
|
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||||
i_start < n && i_start < chunk_size;
|
i_start += blockDim.x * ILP) {
|
||||||
i_start += blockDim.x * ILP)
|
|
||||||
{
|
|
||||||
MATH_T r_g[ILP];
|
MATH_T r_g[ILP];
|
||||||
MATH_T r_p[ILP];
|
MATH_T r_p[ILP];
|
||||||
MATH_T r_m[ILP];
|
MATH_T r_m[ILP];
|
||||||
MATH_T r_v[ILP];
|
MATH_T r_v[ILP];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||||
if (i < n && i < chunk_size)
|
if (i < n && i < chunk_size) {
|
||||||
{
|
|
||||||
r_g[ii] = g[i];
|
r_g[ii] = g[i];
|
||||||
// special ?optimization? for lamb stage 1
|
// special ?optimization? for lamb stage 1
|
||||||
if (decay == 0)
|
if (decay == 0) {
|
||||||
{
|
|
||||||
r_p[ii] = MATH_T(0);
|
r_p[ii] = MATH_T(0);
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
r_p[ii] = p[i];
|
r_p[ii] = p[i];
|
||||||
}
|
}
|
||||||
r_m[ii] = m[i];
|
r_m[ii] = m[i];
|
||||||
r_v[ii] = v[i];
|
r_v[ii] = v[i];
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
r_g[ii] = MATH_T(0);
|
r_g[ii] = MATH_T(0);
|
||||||
r_p[ii] = MATH_T(0);
|
r_p[ii] = MATH_T(0);
|
||||||
r_m[ii] = MATH_T(0);
|
r_m[ii] = MATH_T(0);
|
||||||
|
@ -201,10 +167,8 @@ struct LAMBStage1Functor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
if (mode == MOMENT_MODE_0) {
|
||||||
if (mode == MOMENT_MODE_0)
|
|
||||||
{
|
|
||||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||||
// L2 on scaled grad
|
// L2 on scaled grad
|
||||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||||
|
@ -214,9 +178,7 @@ struct LAMBStage1Functor
|
||||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||||
r_p[ii] = next_m_unbiased / denom;
|
r_p[ii] = next_m_unbiased / denom;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||||
|
@ -227,11 +189,9 @@ struct LAMBStage1Functor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||||
if (i < n && i < chunk_size)
|
if (i < n && i < chunk_size) {
|
||||||
{
|
|
||||||
g[i] = r_p[ii];
|
g[i] = r_p[ii];
|
||||||
m[i] = r_m[ii];
|
m[i] = r_m[ii];
|
||||||
v[i] = r_v[ii];
|
v[i] = r_v[ii];
|
||||||
|
@ -244,19 +204,12 @@ struct LAMBStage1Functor
|
||||||
|
|
||||||
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
|
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
|
||||||
// It computes new parameter value.
|
// It computes new parameter value.
|
||||||
template <typename T>
|
template <typename T> struct LAMBStage2Functor {
|
||||||
struct LAMBStage2Functor
|
__device__ __forceinline__ void
|
||||||
{
|
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
|
||||||
__device__ __forceinline__ void operator()(
|
|
||||||
int chunk_size,
|
|
||||||
volatile int *noop_gmem,
|
|
||||||
TensorListMetadata<2> &tl,
|
|
||||||
const float *per_tensor_param_norm,
|
const float *per_tensor_param_norm,
|
||||||
const float *per_tensor_update_norm,
|
const float *per_tensor_update_norm, const float learning_rate,
|
||||||
const float learning_rate,
|
const float decay, bool use_nvlamb) {
|
||||||
const float decay,
|
|
||||||
bool use_nvlamb)
|
|
||||||
{
|
|
||||||
// I'd like this kernel to propagate infs/nans.
|
// I'd like this kernel to propagate infs/nans.
|
||||||
// if(*noop_gmem == 1)
|
// if(*noop_gmem == 1)
|
||||||
// return;
|
// return;
|
||||||
|
@ -269,11 +222,12 @@ struct LAMBStage2Functor
|
||||||
MATH_T ratio = learning_rate;
|
MATH_T ratio = learning_rate;
|
||||||
// nvlamb: apply adaptive learning rate to all parameters
|
// nvlamb: apply adaptive learning rate to all parameters
|
||||||
// otherwise, only apply to those with non-zero weight decay
|
// otherwise, only apply to those with non-zero weight decay
|
||||||
if (use_nvlamb || (decay != 0.0))
|
if (use_nvlamb || (decay != 0.0)) {
|
||||||
{
|
|
||||||
float param_norm = per_tensor_param_norm[tensor_num];
|
float param_norm = per_tensor_param_norm[tensor_num];
|
||||||
float update_norm = per_tensor_update_norm[tensor_num];
|
float update_norm = per_tensor_update_norm[tensor_num];
|
||||||
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
|
ratio = (update_norm != 0.0f && param_norm != 0.0f)
|
||||||
|
? learning_rate * (param_norm / update_norm)
|
||||||
|
: learning_rate;
|
||||||
}
|
}
|
||||||
|
|
||||||
T *update = (T *)tl.addresses[0][tensor_loc];
|
T *update = (T *)tl.addresses[0][tensor_loc];
|
||||||
|
@ -285,55 +239,44 @@ struct LAMBStage2Functor
|
||||||
n -= chunk_idx * chunk_size;
|
n -= chunk_idx * chunk_size;
|
||||||
|
|
||||||
// to make things simple, we put aligned case in a different code path
|
// to make things simple, we put aligned case in a different code path
|
||||||
if (n % ILP == 0 &&
|
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) &&
|
||||||
chunk_size % ILP == 0 &&
|
is_aligned(update)) {
|
||||||
is_aligned(p) &&
|
|
||||||
is_aligned(update))
|
|
||||||
{
|
|
||||||
T r_p[ILP];
|
T r_p[ILP];
|
||||||
T r_update[ILP];
|
T r_update[ILP];
|
||||||
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
|
for (int i_start = threadIdx.x;
|
||||||
{
|
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||||
|
i_start += blockDim.x) {
|
||||||
// load
|
// load
|
||||||
load_store(r_p, p, 0, i_start);
|
load_store(r_p, p, 0, i_start);
|
||||||
load_store(r_update, update, 0, i_start);
|
load_store(r_update, update, 0, i_start);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
r_p[ii] = static_cast<MATH_T>(r_p[ii]) -
|
||||||
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
|
(ratio * static_cast<MATH_T>(r_update[ii]));
|
||||||
}
|
}
|
||||||
load_store(p, r_p, i_start, 0);
|
load_store(p, r_p, i_start, 0);
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else
|
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||||
{
|
i_start += blockDim.x * ILP) {
|
||||||
for (int i_start = 0;
|
|
||||||
i_start < n && i_start < chunk_size;
|
|
||||||
i_start += blockDim.x * ILP)
|
|
||||||
{
|
|
||||||
MATH_T r_p[ILP];
|
MATH_T r_p[ILP];
|
||||||
MATH_T r_update[ILP];
|
MATH_T r_update[ILP];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||||
if (i < n && i < chunk_size)
|
if (i < n && i < chunk_size) {
|
||||||
{
|
|
||||||
r_p[ii] = p[i];
|
r_p[ii] = p[i];
|
||||||
r_update[ii] = update[i];
|
r_update[ii] = update[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
|
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < ILP; ii++)
|
for (int ii = 0; ii < ILP; ii++) {
|
||||||
{
|
|
||||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||||
if (i < n && i < chunk_size)
|
if (i < n && i < chunk_size) {
|
||||||
{
|
|
||||||
p[i] = r_p[ii];
|
p[i] = r_p[ii];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -342,33 +285,25 @@ struct LAMBStage2Functor
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void multi_tensor_lamb_cuda(
|
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
|
||||||
int chunk_size,
|
|
||||||
at::Tensor noop_flag,
|
|
||||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||||
const float lr,
|
const float lr, const float beta1,
|
||||||
const float beta1,
|
const float beta2, const float epsilon,
|
||||||
const float beta2,
|
const int step, const int bias_correction,
|
||||||
const float epsilon,
|
const float weight_decay, const int grad_averaging,
|
||||||
const int step,
|
const int mode, at::Tensor global_grad_norm,
|
||||||
const int bias_correction,
|
|
||||||
const float weight_decay,
|
|
||||||
const int grad_averaging,
|
|
||||||
const int mode,
|
|
||||||
at::Tensor global_grad_norm,
|
|
||||||
const float max_grad_norm,
|
const float max_grad_norm,
|
||||||
at::optional<bool> use_nvlamb_python)
|
at::optional<bool> use_nvlamb_python) {
|
||||||
{
|
|
||||||
using namespace at;
|
using namespace at;
|
||||||
// Master weight and 32bit momentum(potentially changing) is not handled by this
|
// Master weight and 32bit momentum(potentially changing) is not handled by
|
||||||
// So we assume every tensor are all in the same type
|
// this So we assume every tensor are all in the same type
|
||||||
|
|
||||||
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
|
bool use_nvlamb =
|
||||||
|
use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
|
||||||
|
|
||||||
// Handle bias correction mode
|
// Handle bias correction mode
|
||||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||||
if (bias_correction == 1)
|
if (bias_correction == 1) {
|
||||||
{
|
|
||||||
bias_correction1 = 1 - std::pow(beta1, step);
|
bias_correction1 = 1 - std::pow(beta1, step);
|
||||||
bias_correction2 = 1 - std::pow(beta2, step);
|
bias_correction2 = 1 - std::pow(beta2, step);
|
||||||
}
|
}
|
||||||
|
@ -378,50 +313,42 @@ void multi_tensor_lamb_cuda(
|
||||||
if (grad_averaging == 1)
|
if (grad_averaging == 1)
|
||||||
beta3 = 1 - beta1;
|
beta3 = 1 - beta1;
|
||||||
|
|
||||||
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);
|
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
|
||||||
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2);
|
tensor_lists.begin() + 1);
|
||||||
|
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1,
|
||||||
|
tensor_lists.begin() + 2);
|
||||||
|
|
||||||
// Compute per tensor param norm
|
// Compute per tensor param norm
|
||||||
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
|
auto param_norm_tuple =
|
||||||
|
multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
|
||||||
|
|
||||||
// We now in-place modify grad to store update before compute its norm
|
// We now in-place modify grad to store update before compute its norm
|
||||||
// Generally this is not a issue since people modify grad in step() method all the time
|
// Generally this is not a issue since people modify grad in step() method all
|
||||||
// We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
|
// the time We can also grab list of empty tensor to avoid this, but I'd like
|
||||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
|
// to save space/cpu code
|
||||||
multi_tensor_apply<4>(
|
DISPATCH_FLOAT_AND_HALF(
|
||||||
BLOCK_SIZE,
|
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
|
||||||
chunk_size,
|
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||||
noop_flag,
|
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
|
||||||
tensor_lists,
|
|
||||||
LAMBStage1Functor<scalar_t_0>(),
|
|
||||||
beta1,
|
|
||||||
beta2,
|
|
||||||
beta3, // 1-beta1 or 1 depends on averaging mode
|
beta3, // 1-beta1 or 1 depends on averaging mode
|
||||||
bias_correction1,
|
bias_correction1, bias_correction2, epsilon,
|
||||||
bias_correction2,
|
(adamMode_t)mode, weight_decay,
|
||||||
epsilon,
|
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
|
||||||
(adamMode_t)mode,
|
|
||||||
weight_decay,
|
|
||||||
global_grad_norm.DATA_PTR<float>(),
|
|
||||||
max_grad_norm);)
|
|
||||||
|
|
||||||
// Compute update norms
|
// Compute update norms
|
||||||
auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
|
auto update_norm_tuple =
|
||||||
|
multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
|
||||||
|
|
||||||
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2);
|
std::vector<std::vector<at::Tensor>> grad_param_list(
|
||||||
|
tensor_lists.begin(), tensor_lists.begin() + 2);
|
||||||
|
|
||||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
|
DISPATCH_FLOAT_AND_HALF(
|
||||||
multi_tensor_apply<2>(
|
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
|
||||||
BLOCK_SIZE,
|
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,
|
||||||
chunk_size,
|
|
||||||
noop_flag,
|
|
||||||
grad_param_list,
|
|
||||||
LAMBStage2Functor<scalar_t_0>(),
|
LAMBStage2Functor<scalar_t_0>(),
|
||||||
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
|
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
|
||||||
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
|
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
|
||||||
lr,
|
lr, weight_decay, use_nvlamb);)
|
||||||
weight_decay,
|
|
||||||
use_nvlamb);)
|
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
Loading…
Reference in New Issue