mirror of https://github.com/hpcaitech/ColossalAI
[zero] cpu adam kernel (#288)
* Added CPU Adam
* Finished the CPU Adam kernel
* Updated the license
* Deleted useless parameters, removed resnet
* Modified the method of the CPU Adam unittest
* Deleted some useless code

Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
parent 90d3aef62c
commit a3269de5c9

colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -0,0 +1,517 @@
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
#include "cpu_adam.h"

#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <string.h>

// registry of live optimizer instances, keyed by the id handed over from Python
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface
void Adam_Optimizer::Step_1(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            bool param_half_precision,
                            bool grad_half_precision,
                            float loss_scale)
{
    size_t rounded_size = 0;

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    // learning rate and decoupled weight decay are pre-negated so the final
    // parameter update is a single fused multiply-add
    float step_size = -1 * _alpha / _bias_correction1;
    float w_decay = -1 * _alpha * _weight_decay;

    __half* params_cast_h = NULL;
    __half* grads_cast_h = NULL;
    if (param_half_precision) {
        params_cast_h = reinterpret_cast<__half*>(_params);
    }
    if (grad_half_precision) {
        grads_cast_h = reinterpret_cast<__half*>(grads);
    }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    AVX_Data weight_decay_4;
    if (_weight_decay > 0)
        weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;

#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH) {
            AVX_Data grad_4;
            if (grad_half_precision) {
                grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
            } else {
                grad_4.data = SIMD_LOAD(grads + i);
            }
            if (loss_scale > 0) {
                AVX_Data loss_scale_vec;
                loss_scale_vec.data = SIMD_SET(loss_scale);
                grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
            }
            AVX_Data momentum_4;
            momentum_4.data = SIMD_LOAD(_exp_avg + i);

            AVX_Data variance_4;
            variance_4.data = SIMD_LOAD(_exp_avg_sq + i);

            AVX_Data param_4;
            if (param_half_precision) {
                param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
            } else {
                param_4.data = SIMD_LOAD(_params + i);
            }

            if (_weight_decay > 0 && !_adamw_mode) {
                grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
            }
            momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
            momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
            variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
            grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
            variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
            grad_4.data = SIMD_SQRT(variance_4.data);
            grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
            grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);

            if (_weight_decay > 0 && _adamw_mode) {
                param_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
            }
            param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

            if (param_half_precision) {
                SIMD_STORE_HALF((float*)(params_cast_h + i), param_4.data);
            } else {
                SIMD_STORE(_params + i, param_4.data);
            }
            SIMD_STORE(_exp_avg + i, momentum_4.data);
            SIMD_STORE(_exp_avg_sq + i, variance_4.data);
        }
    }
#endif
    // scalar fallback for the tail that is not a multiple of SIMD_WIDTH
    // (and for builds without AVX support)
    if (_param_size > rounded_size) {
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;

#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
                if (loss_scale > 0) { grad /= loss_scale; }
                float param = param_half_precision ? (float)params_cast_h[k] : _params[k];
                float momentum = _exp_avg[k];
                float variance = _exp_avg_sq[k];
                if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
                momentum = momentum * _betta1;
                momentum = grad * betta1_minus1 + momentum;

                variance = variance * _betta2;
                grad = grad * grad;
                variance = grad * betta2_minus1 + variance;

                grad = sqrt(variance);
                grad = grad * _bias_correction2 + _eps;
                grad = momentum / grad;
                if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
                param = grad * step_size + param;

                if (param_half_precision)
                    params_cast_h[k] = (__half)param;
                else
                    _params[k] = param;
                _exp_avg[k] = momentum;
                _exp_avg_sq[k] = variance;
            }
        }
    }
}
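For readers tracing the arithmetic: Step_1 applies the standard Adam/AdamW update, with the learning rate and decay pre-negated so each step is a fused multiply-add. A minimal scalar transcription in Python (illustrative only; these names are not part of the extension API):

```python
import math

def adam_step_scalar(param, grad, m, v, step, alpha=1e-3, beta1=0.9,
                     beta2=0.999, eps=1e-8, weight_decay=0.0,
                     adamw_mode=True, loss_scale=-1.0):
    # bias corrections as computed by update_state in cpu_adam.h
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 / math.sqrt(1 - beta2 ** step)
    step_size = -alpha / bias_correction1      # negative, like the kernel
    w_decay = -alpha * weight_decay

    if loss_scale > 0:
        grad /= loss_scale                     # undo dynamic loss scaling
    if weight_decay > 0 and not adamw_mode:
        grad += param * weight_decay           # classic L2: decay folded into grad
    m = beta1 * m + (1 - beta1) * grad         # first moment
    v = beta2 * v + (1 - beta2) * grad * grad  # second moment
    denom = math.sqrt(v) * bias_correction2 + eps
    if weight_decay > 0 and adamw_mode:
        param += w_decay * param               # decoupled (AdamW) decay
    param += step_size * (m / denom)
    return param, m, v
```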

void Adam_Optimizer::Step_4(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            bool param_half_precision,
                            bool grad_half_precision,
                            float loss_scale)
{
    size_t rounded_size = 0;

    __half* params_cast_h = NULL;
    __half* grads_cast_h = NULL;
    if (param_half_precision) {
        params_cast_h = reinterpret_cast<__half*>(_params);
    }
    if (grad_half_precision) {
        grads_cast_h = reinterpret_cast<__half*>(grads);
    }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay_4;
    if (_weight_decay > 0)
        weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;

#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
            // process four SIMD vectors per iteration to expose more
            // instruction-level parallelism
            AVX_Data grad_4[4];
            AVX_Data momentum_4[4];
            AVX_Data variance_4[4];
            AVX_Data param_4[4];
#pragma unroll 4
            for (int j = 0; j < 4; j++) {
                if (grad_half_precision) {
                    grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
                } else {
                    grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
                }

                if (loss_scale > 0) {
                    AVX_Data loss_scale_vec;
                    loss_scale_vec.data = SIMD_SET(loss_scale);
                    grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
                }

                momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
                variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);

                if (param_half_precision) {
                    param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
                } else {
                    param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
                }

                if (_weight_decay > 0 && !_adamw_mode) {
                    grad_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
                }
                momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
                momentum_4[j].data = SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
                variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
                grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
                variance_4[j].data = SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
                grad_4[j].data = SIMD_SQRT(variance_4[j].data);
                grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
                grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);

                if (_weight_decay > 0 && _adamw_mode) {
                    param_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
                }
                param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
                if (param_half_precision) {
                    SIMD_STORE_HALF((float*)(params_cast_h + i + SIMD_WIDTH * j), param_4[j].data);
                } else {
                    SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
                }
                SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
                SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
            }
        }
    }
#endif
    // delegate the tail that is not a multiple of SIMD_WIDTH * 4 to Step_1
    if (_param_size > rounded_size)
        Step_1((param_half_precision ? (float*)(params_cast_h + rounded_size) : _params + rounded_size),
               (grad_half_precision ? (float*)(grads_cast_h + rounded_size) : grads + rounded_size),
               (_exp_avg + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               param_half_precision,
               grad_half_precision,
               loss_scale);
}

int create_adam_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float eps = 1e-8,
                          float weight_decay = 0,
                          bool adamw_mode = true,
                          bool should_log = false)
{
    auto opt =
        std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);

    s_optimizers[optimizer_id] = opt;

    if (should_log) {
        std::string avx_type = "";
#if defined(__AVX512__)
        avx_type = "AVX512";
#else
#if defined(__AVX256__) or defined(__AVX2__)
        avx_type = "AVX2";
#else
        avx_type = "scalar";
#endif
#endif
        printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type.c_str());
        printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
               alpha,
               betta1,
               betta2,
               weight_decay,
               (int)adamw_mode);
    }

    return 0;
}

void Adam_Optimizer::Step_8(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            bool param_half_precision,
                            bool grad_half_precision,
                            float loss_scale)
{
    size_t rounded_size = 0;
    __half* params_cast_h = NULL;
    __half* grads_cast_h = NULL;
    if (param_half_precision) {
        params_cast_h = reinterpret_cast<__half*>(_params);
    }
    if (grad_half_precision) {
        grads_cast_h = reinterpret_cast<__half*>(grads);
    }
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay_4;
    if (_weight_decay > 0)
        weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;

#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
            // eight SIMD vectors per iteration: the widest unrolled variant
            AVX_Data grad_4[8];
            AVX_Data momentum_4[8];
            AVX_Data variance_4[8];
            AVX_Data param_4[8];
#pragma unroll 8
            for (int j = 0; j < 8; j++) {
                if (grad_half_precision) {
                    grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
                } else {
                    grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
                }

                if (loss_scale > 0) {
                    AVX_Data loss_scale_vec;
                    loss_scale_vec.data = SIMD_SET(loss_scale);
                    grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
                }

                momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
                variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);

                if (param_half_precision) {
                    param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
                } else {
                    param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
                }

                if (_weight_decay > 0 && !_adamw_mode) {
                    grad_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
                }
                momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
                momentum_4[j].data = SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
                variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
                grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
                variance_4[j].data = SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
                grad_4[j].data = SIMD_SQRT(variance_4[j].data);
                grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
                grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
                if (_weight_decay > 0 && _adamw_mode) {
                    param_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
                }
                param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);

                if (param_half_precision) {
                    SIMD_STORE_HALF((float*)(params_cast_h + i + SIMD_WIDTH * j), param_4[j].data);
                } else {
                    SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
                }

                SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
                SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
            }
        }
    }
#endif
    // delegate the tail that is not a multiple of SIMD_WIDTH * 8 to Step_4
    if (_param_size > rounded_size)
        Step_4((param_half_precision ? (float*)(params_cast_h + rounded_size) : _params + rounded_size),
               (grad_half_precision ? (float*)(grads_cast_h + rounded_size) : grads + rounded_size),
               (_exp_avg + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               param_half_precision,
               grad_half_precision,
               loss_scale);
}
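Step_8, Step_4 and Step_1 form a remainder cascade: each handles the largest prefix that is a multiple of SIMD_WIDTH * 8, * 4 and * 1 respectively, and Step_1's scalar loop finishes whatever is left. A small sketch of how a parameter count gets partitioned (plain Python, illustrative names; SIMD_WIDTH = 8 corresponds to the AVX2 build):

```python
def round_down(size, step):
    # mirrors the ROUND_DOWN macro; step must be a power of two
    return size & ~(step - 1)

def plan_chunks(n, simd_width):
    """Return (unroll_factor, start, length) work items as the cascade would."""
    plan, start = [], 0
    for unroll in (8, 4, 1):
        rounded = round_down(n - start, simd_width * unroll)
        if rounded > 0:
            plan.append((unroll, start, rounded))
            start += rounded
    if start < n:                      # scalar tail inside Step_1
        plan.append(("scalar", start, n - start))
    return plan

print(plan_chunks(1023, 8))
# [(8, 0, 960), (4, 960, 32), (1, 992, 24), ('scalar', 1016, 7)]
```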

int adam_step(int optimizer_id,
              size_t step,
              float lr,
              float beta1,
              float beta2,
              float epsilon,
              float weight_decay,
              bool bias_correction,
              torch::Tensor& params,
              torch::Tensor& grads,
              torch::Tensor& exp_avg,
              torch::Tensor& exp_avg_sq,
              float loss_scale)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
    std::shared_ptr<Adam_Optimizer> opt =
        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, epsilon, weight_decay, bias_correction);
    // tensors are expected to be flattened to 1-D, so size(0) is the element count
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                exp_avg_sq_ptr,
                params_c.size(0),
                (params.options().dtype() == at::kHalf),
                (grads.options().dtype() == at::kHalf),
                loss_scale);

    return 0;
}

int destroy_adam_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);
    return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adam_update", &adam_step, "CPU Adam update (C++)");
    m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)");
    m.def("destroy_adam", &destroy_adam_optimizer, "CPU Adam destroy (C++)");
}
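A minimal sketch of driving the compiled extension directly through these bindings (the same call pattern the unit test below uses); tensors must be contiguous, flattened CPU tensors:

```python
import torch
import cpu_adam  # the extension built by this commit's setup.py changes

cpu_adam.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False)

param = torch.rand(64)
grad = torch.rand(64)
exp_avg = torch.zeros(64)
exp_avg_sq = torch.zeros(64)

# one optimizer step: (opt_id, step, lr, beta1, beta2, eps, weight_decay,
#                      bias_correction, params, grads, exp_avg, exp_avg_sq, loss_scale)
cpu_adam.adam_update(0, 1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True,
                     param.view(-1), grad.view(-1),
                     exp_avg.view(-1), exp_avg_sq.view(-1), -1)

cpu_adam.destroy_adam(0)
```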

colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -0,0 +1,163 @@
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
#pragma once

#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cmath>  // std::pow / sqrt used in the class below

#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif

// round size down to the nearest multiple of step (step must be a power of two)
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)

#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x)))
#define SIMD_STORE_HALF(x, d) _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x)))
#define SIMD_STORE_HALF(x, d) _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

#endif

union AVX_Data {
#if defined(__AVX512__)
    __m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
    __m256 data;
#endif
    // float data_f[16];
};

#endif
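SIMD_WIDTH is the register width in floats (512/32 = 16 for AVX-512, 256/32 = 8 for AVX2/AVX256). SIMD_LOAD_HALF and SIMD_STORE_HALF convert packed fp16 lanes to and from fp32, which is what lets the Step_N kernels consume a __half buffer through the same code path as a float buffer. The same round trip in torch, for intuition (a sketch, not part of this commit):

```python
import torch

x = torch.rand(8)              # fp32 lane values
h = x.to(torch.float16)        # what SIMD_STORE_HALF writes (fp16, round-to-nearest)
back = h.to(torch.float32)     # what SIMD_LOAD_HALF reads back as fp32
print((x - back).abs().max())  # rounding error on the order of 1e-4 for [0, 1)
```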

#define STEP(SPAN)                                      \
    void Step_##SPAN(float* _params,                    \
                     float* grads,                      \
                     float* _exp_avg,                   \
                     float* _exp_avg_sq,                \
                     size_t _param_size,                \
                     bool param_half_precision = false, \
                     bool grad_half_precision = false,  \
                     float loss_scale = -1);

class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0,
                   bool adamw_mode = true)
        : _alpha(alpha),
          _betta1(betta1),
          _betta2(betta2),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0),
          _betta2_t(1.0),
          _step(0),
          _adamw_mode(adamw_mode) {}
    ~Adam_Optimizer() {}

    STEP(1)
    STEP(4)
    STEP(8)

    // keep the cached powers beta1^t and beta2^t in sync with the step count;
    // the common case (step advanced by one) costs a single multiply
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        if (beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
            _betta1_t = std::pow(_betta1, step);
            _betta2_t = std::pow(_betta2, step);
        } else {
            _step++;
            if (_step != step) {
                _betta1_t = std::pow(_betta1, step);
                _betta2_t = std::pow(_betta2, step);
                _step = step;
            } else {
                _betta1_t *= _betta1;
                _betta2_t *= _betta2;
            }
        }
    }

    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;

        _bias_correction1 = 1.0f;
        _bias_correction2 = 1.0f;
        if (bias_correction == 1) {
            _bias_correction1 = 1 - _betta1_t;
            _bias_correction2 = 1 / sqrt(1 - _betta2_t);
        }
    }

private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _eps;
    float _weight_decay;

    float _betta1_t;
    float _betta2_t;
    size_t _step;

    float _bias_correction1;
    float _bias_correction2;

    bool _adamw_mode;
};
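The class caches beta1^t and beta2^t so that bias correction does not recompute a pow() on every step. The logic is easiest to read in plain Python (an illustrative transcription, not part of the API):

```python
import math

class StepState:
    def __init__(self, beta1=0.9, beta2=0.999):
        self.beta1, self.beta2 = beta1, beta2
        self.beta1_t, self.beta2_t = 1.0, 1.0   # beta^0
        self.step = 0

    def increment_step(self, step, beta1, beta2):
        if (beta1, beta2) != (self.beta1, self.beta2):
            # betas changed: recompute the cached powers from scratch
            self.beta1, self.beta2 = beta1, beta2
            self.beta1_t, self.beta2_t = beta1 ** step, beta2 ** step
            self.step = step
        elif step == self.step + 1:
            # the common case: advance the cached powers by one multiply
            self.step += 1
            self.beta1_t *= self.beta1
            self.beta2_t *= self.beta2
        else:
            # out-of-order step count: fall back to pow()
            self.step = step
            self.beta1_t, self.beta2_t = self.beta1 ** step, self.beta2 ** step

    def bias_corrections(self):
        # matches update_state with bias_correction enabled
        return 1 - self.beta1_t, 1 / math.sqrt(1 - self.beta2_t)
```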

colossalai/nn/optimizer/__init__.py
@@ -4,7 +4,8 @@ from .fused_lamb import FusedLAMB
 from .fused_sgd import FusedSGD
 from .lamb import Lamb
 from .lars import Lars
+from .cpu_adam import CPUAdam

 __all__ = [
-    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars'
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam'
 ]

colossalai/nn/optimizer/cpu_adam.py
@@ -0,0 +1,103 @@
# modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/adam/cpu_adam.py

import math
import torch
import time
from pathlib import Path
import colossalai


class CPUAdam(torch.optim.Optimizer):
    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 loss_scale=-1,
                 simd_log=False):

        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args)
        # give each instance its own slot in the C++ optimizer registry
        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
        self.adam_w_mode = adamw_mode
        self.loss_scale = loss_scale
        try:
            import cpu_adam
        except ImportError:
            raise ImportError('Please install colossalai from source code to use CPUAdam')
        self.cpu_adam_op = cpu_adam
        self.cpu_adam_op.create_adam(self.opt_id,
                                     lr,
                                     betas[0],
                                     betas[1],
                                     eps,
                                     weight_decay,
                                     adamw_mode,
                                     simd_log)

    def __del__(self):
        self.cpu_adam_op.destroy_adam(self.opt_id)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure cpu_offload is set to True"

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data,
                                                        dtype=torch.float,
                                                        device=device)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data,
                                                           dtype=torch.float,
                                                           device=device)
                    # memory_format=torch.preserve_format)

                state['step'] += 1
                beta1, beta2 = group['betas']

                self.cpu_adam_op.adam_update(self.opt_id,
                                             state['step'],
                                             group['lr'],
                                             beta1,
                                             beta2,
                                             group['eps'],
                                             group['weight_decay'],
                                             group['bias_correction'],
                                             p.data,
                                             p.grad.data,
                                             state['exp_avg'],
                                             state['exp_avg_sq'],
                                             self.loss_scale)
        return loss
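A minimal usage sketch (the toy model and tensor names are illustrative; this assumes colossalai was built from source so the cpu_adam extension above is compiled):

```python
import torch
from colossalai.nn.optimizer import CPUAdam

# a toy CPU model; CPUAdam requires parameters and gradients to live on the CPU
model = torch.nn.Linear(32, 4)
optimizer = CPUAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

x, y = torch.rand(8, 32), torch.rand(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # runs the vectorized C++ kernel
optimizer.zero_grad()
```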

@@ -45,7 +45,9 @@ class ShardedOptimizer(ColossalaiOptimizer):
             mp_parallel_mode=ParallelMode.MODEL,

             # cpu offload
-            cpu_offload=False):
+            cpu_offload=False,
+            cpu_fp16_param=False,
+            cpu_fp16_grad=False):

     # TODO: add support for
     # 1. fp16 master weights

@@ -63,6 +65,8 @@ class ShardedOptimizer(ColossalaiOptimizer):
         # cpu_offload
         self._cpu_offload = cpu_offload
+        self._cpu_fp16_param = cpu_fp16_param
+        self._cpu_fp16_grad = cpu_fp16_grad

         # get process groups
         self._dp_parallel_mode = dp_parallel_mode

@@ -146,7 +150,11 @@ class ShardedOptimizer(ColossalaiOptimizer):
         # create a copy of fp32 weights of the parameters for which this rank is responsible
         fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
-        fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
+        # when using cpu offload, our cpu adam supports fp16 parameters
+        if self._cpu_fp16_param:
+            fp32_flat_current_rank = fp16_flat_current_rank.detach()
+        else:
+            fp32_flat_current_rank = fp16_flat_current_rank.detach().float()
         device = 'cpu' if self._cpu_offload else get_current_device()
         fp32_flat_current_rank = fp32_flat_current_rank.to(device)
         fp32_flat_current_rank.requires_grad = True
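The effect of cpu_fp16_param is that the rank-local master copy keeps its fp16 storage instead of being widened to fp32, roughly halving CPU memory for offloaded parameters; the CPU Adam kernel then consumes it through its param_half_precision path. A small illustration of the two branches:

```python
import torch

fp16_flat = torch.rand(4, dtype=torch.float16)
master_fp16 = fp16_flat.detach()             # cpu_fp16_param=True: stays fp16
master_fp32 = fp16_flat.detach().float()     # default: fp32 master copy (2x memory)
print(master_fp16.dtype, master_fp32.dtype)  # torch.float16 torch.float32
```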

@@ -209,7 +217,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
         fp32_partition_grad = torch.zeros_like(fp32_partition_param)
         fp32_partition_param.grad = fp32_partition_grad

         # update the parameter with zero gradients for initialization of optimizer states
         self._optimizer.step()

         # remove the grad of the parameter to save memory

setup.py
@@ -124,12 +124,12 @@ if build_cuda_ext:
 # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
 version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

-def cuda_ext_helper(name, sources, extra_cuda_flags):
+def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]):
     return CUDAExtension(name=name,
                          sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources],
                          include_dirs=[os.path.join(this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')],
-                         extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                         extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + extra_cxx_flags,
                                              'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)})

@@ -188,6 +188,12 @@ if build_cuda_ext:
                                        'kernels/general_kernels.cu',
                                        'kernels/cuda_util.cu'],
                                       extra_cuda_flags + cc_flag))

+    extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
+    ext_modules.append(cuda_ext_helper('cpu_adam',
+                                       ['cpu_adam.cpp'],
+                                       extra_cuda_flags,
+                                       extra_cxx_flags))
+
 setup(
     name='colossalai',

@@ -0,0 +1,197 @@
# BSD 3-Clause License
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name of the psutil authors nor the names of its contributors
#    may be used to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import math
import torch
import colossalai

try:
    import cpu_adam
except ImportError:
    raise ImportError('import cpu_adam error: build colossalai from source to compile the extension')

def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    bias_correction,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    loss_scale,
    use_adamw,
):
    """Reference implementation: plain PyTorch Adam/AdamW used to check the kernel."""
    if loss_scale > 0:
        grad.div_(loss_scale)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    if weight_decay != 0:
        if use_adamw:
            # Perform stepweight decay
            param.mul_(1 - lr * weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)

class Test():
    def __init__(self):
        self.opt_id = 0

    def assertLess(self, data_diff, threshold, msg):
        assert data_diff < threshold, msg

    def assertTrue(self, condition, msg):
        assert condition, msg

    def check_res(
        self,
        step,
        lr,
        eps,
        beta1,
        beta2,
        weight_decay,
        shape,
        grad_dtype,
        loss_scale,
        use_adamw,
        cpu_adam_op,
    ):
        p_data = torch.rand(shape, dtype=grad_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(shape, dtype=grad_dtype)
        if loss_scale > 0:
            p_grad.mul_(loss_scale)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
        cpu_adam_op.adam_update(
            self.opt_id,
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,
            p_data.view(-1),      # fp32 or fp16 data
            p_grad.view(-1),      # fp32 or fp16 grad
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            loss_scale,
        )

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,
            p_data_copy,    # fp32 data
            p_grad_copy,    # fp32 grad
            exp_avg_copy,
            exp_avg_sq_copy,
            loss_scale,
            use_adamw,
        )

        if loss_scale > 0:
            p_grad.div_(loss_scale)

        var = p_data_copy - p_data
        data_diff = torch.max(torch.abs(var))
        # fp16 tensors tolerate a larger error than fp32 ones
        threshold = 2e-3 if grad_dtype == torch.half else 1e-4
        self.assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        self.assertTrue(
            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
        )
    def test_cpu_adam(self):
        lr = 0.9
        eps = 1e-6
        weight_decay = 0
        for use_adamw in [False, True]:
            for shape in [(1023, ), (32, 1024)]:
                for step in range(1, 2):
                    for lr in [0.01]:
                        for eps in [1e-8]:
                            for beta1 in [0.9]:
                                for beta2 in [0.999]:
                                    for weight_decay in [0.001]:
                                        for grad_dtype in [torch.half, torch.float]:
                                            for loss_scale in [-1, 2 ** 5]:
                                                self.check_res(
                                                    step,
                                                    lr,
                                                    eps,
                                                    beta1,
                                                    beta2,
                                                    weight_decay,
                                                    shape,
                                                    grad_dtype,
                                                    loss_scale,
                                                    use_adamw,
                                                    cpu_adam,
                                                )


if __name__ == "__main__":
    test = Test()
    test.test_cpu_adam()
    print('All is well.')