mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#936)
parent 44b6f8947b
commit b6cc9313ef
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -20,12 +20,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE
 */
 #include "cpu_adam.h"
-#include <iostream>
+
 #include <math.h>
-#include <memory>
 #include <omp.h>
 #include <string.h>
 #include <torch/extension.h>
+
+#include <iostream>
+#include <memory>
 #include <type_traits>
 #include <unordered_map>
 
@@ -82,8 +84,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
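
The hunk above (and the near-identical ones below for the Step_1 tail loop, Step_4 and Step_8) only folds the clamp onto one line; the tiling pattern itself is unchanged. As a point of reference, a minimal standalone sketch of that pattern is shown here. The TILE value, the function name and the placeholder update are illustrative, not taken from cpu_adam.h:

#include <omp.h>

#include <cstddef>

// Process the parameter array in fixed-size chunks; the last chunk is clamped.
constexpr size_t TILE = 1024 * 1024;

void tiled_update(float *params, const float *grads, size_t param_size) {
  for (size_t t = 0; t < param_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > param_size) copy_size = param_size - t;  // clamp last tile

    // Each tile is updated element-wise across OpenMP threads.
#pragma omp parallel for
    for (size_t k = t; k < t + copy_size; ++k) {
      params[k] -= 1e-3f * grads[k];  // stand-in for the real Adam update
    }
  }
}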
@@ -145,8 +146,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_param_size > rounded_size) {
     for (size_t t = rounded_size; t < _param_size; t += TILE) {
       size_t copy_size = TILE;
-      if ((t + TILE) > _param_size)
-        copy_size = _param_size - t;
+      if ((t + TILE) > _param_size) copy_size = _param_size - t;
       size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -235,8 +235,7 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -321,7 +320,6 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
   s_optimizers[optimizer_id] = opt;
 
   if (should_log) {
-
     std::string avx_type = "";
 #if defined(__AVX512__)
     avx_type = "AVX512";
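
For context, s_optimizers in the hunk above is the process-wide registry that create_adam_optimizer fills and that adam_step later reads (see the last hunk below). A minimal sketch of that registry pattern follows; the stand-in types are made up for illustration and are not the repository's Adam_Optimizer class:

#include <memory>
#include <unordered_map>

// Stand-ins for the real classes; only the registry pattern is the point here.
struct Optimizer_Base {
  virtual ~Optimizer_Base() = default;
};
struct Fake_Adam : Optimizer_Base {
  explicit Fake_Adam(float a) : alpha(a) {}
  float alpha;
};

// Registry keyed by the integer id that the caller passes around.
static std::unordered_map<int, std::shared_ptr<Optimizer_Base>> s_registry;

int create_fake_adam(int optimizer_id, float alpha) {
  auto opt = std::make_shared<Fake_Adam>(alpha);
  s_registry[optimizer_id] = opt;  // mirrors s_optimizers[optimizer_id] = opt;
  return 0;
}

int fake_step(int optimizer_id) {
  // Later calls recover the concrete type the same way adam_step does,
  // via std::static_pointer_cast on the stored base pointer.
  auto opt = std::static_pointer_cast<Fake_Adam>(s_registry[optimizer_id]);
  return opt ? 0 : -1;
}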
@@ -386,8 +384,7 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -463,43 +460,29 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
            grad_half_precision, loss_scale);
 }
 
-int adam_step(int optimizer_id,
-              size_t step,
-              float lr,
-              float beta1,
-              float beta2,
-              float epsilon,
-              float weight_decay,
-              bool bias_correction,
-              torch::Tensor& params,
-              torch::Tensor& grads,
-              torch::Tensor& exp_avg,
-              torch::Tensor& exp_avg_sq,
-              float loss_scale)
-{
-    auto params_c = params.contiguous();
-    auto grads_c = grads.contiguous();
-    auto exp_avg_c = exp_avg.contiguous();
-    auto exp_avg_sq_c = exp_avg_sq.contiguous();
+int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
+              float epsilon, float weight_decay, bool bias_correction,
+              torch::Tensor &params, torch::Tensor &grads,
+              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
+              float loss_scale) {
+  auto params_c = params.contiguous();
+  auto grads_c = grads.contiguous();
+  auto exp_avg_c = exp_avg.contiguous();
+  auto exp_avg_sq_c = exp_avg_sq.contiguous();
 
-    float* params_ptr = (float*)params_c.data_ptr();
-    float* grads_ptr = (float*)grads_c.data_ptr();
-    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
-    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
-    std::shared_ptr<Adam_Optimizer> opt =
-        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-    opt->IncrementStep(step, beta1, beta2);
-    opt->update_state(lr, epsilon, weight_decay, bias_correction);
-    opt->Step_8(params_ptr,
-                grads_ptr,
-                exp_avg_ptr,
-                exp_avg_sq_ptr,
-                params_c.numel(),
-                (params.options().dtype() == at::kHalf),
-                (grads.options().dtype() == at::kHalf),
-                loss_scale);
+  float *params_ptr = (float *)params_c.data_ptr();
+  float *grads_ptr = (float *)grads_c.data_ptr();
+  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
+  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
+  std::shared_ptr<Adam_Optimizer> opt =
+      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+  opt->IncrementStep(step, beta1, beta2);
+  opt->update_state(lr, epsilon, weight_decay, bias_correction);
+  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
+              params_c.numel(), (params.options().dtype() == at::kHalf),
+              (grads.options().dtype() == at::kHalf), loss_scale);
 
   return 0;
 }
 
 int destroy_adam_optimizer(int optimizer_id) {
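
The adam_step signature and the Step_8 call above are only reflowed; the arguments still map onto the standard Adam hyper-parameters. As a reminder of what the exp_avg / exp_avg_sq buffers hold, here is a scalar sketch of the textbook Adam update with bias correction and L2-style weight decay. It is a reference only: the real Step_* paths are vectorized, and they additionally handle fp16 tensors and loss scaling; whether weight decay is applied in L2 or decoupled (AdamW) form is not visible in this diff.

#include <cmath>
#include <cstddef>

// Scalar reference of the Adam update that the vectorized Step_1/4/8 paths
// compute. exp_avg is the first-moment estimate, exp_avg_sq the second.
void adam_reference(float *params, const float *grads, float *exp_avg,
                    float *exp_avg_sq, size_t n, size_t step, float lr,
                    float beta1, float beta2, float eps, float weight_decay) {
  const float bc1 = 1.0f - std::pow(beta1, static_cast<float>(step));
  const float bc2 = 1.0f - std::pow(beta2, static_cast<float>(step));
  for (size_t i = 0; i < n; ++i) {
    float g = grads[i] + weight_decay * params[i];  // L2-style weight decay
    exp_avg[i] = beta1 * exp_avg[i] + (1.0f - beta1) * g;
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0f - beta2) * g * g;
    float m_hat = exp_avg[i] / bc1;                 // bias-corrected moments
    float v_hat = exp_avg_sq[i] / bc2;
    params[i] -= lr * m_hat / (std::sqrt(v_hat) + eps);
  }
}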