[NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#936)

pull/997/head
Geng Zhang 2022-05-13 15:37:30 +08:00 committed by binmakeswell
parent 44b6f8947b
commit b6cc9313ef
1 changed files with 29 additions and 46 deletions

View File

@ -20,12 +20,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE SOFTWARE
*/ */
#include "cpu_adam.h" #include "cpu_adam.h"
#include <iostream>
#include <math.h> #include <math.h>
#include <memory>
#include <omp.h> #include <omp.h>
#include <string.h> #include <string.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
@ -82,8 +84,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) { for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > rounded_size) if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
copy_size = rounded_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
@ -145,8 +146,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
if (_param_size > rounded_size) { if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) { for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > _param_size) if ((t + TILE) > _param_size) copy_size = _param_size - t;
copy_size = _param_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
@ -235,8 +235,7 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) { for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > rounded_size) if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
copy_size = rounded_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
@ -321,7 +320,6 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
s_optimizers[optimizer_id] = opt; s_optimizers[optimizer_id] = opt;
if (should_log) { if (should_log) {
std::string avx_type = ""; std::string avx_type = "";
#if defined(__AVX512__) #if defined(__AVX512__)
avx_type = "AVX512"; avx_type = "AVX512";
@ -386,8 +384,7 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) { for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > rounded_size) if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
copy_size = rounded_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
@ -463,43 +460,29 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
grad_half_precision, loss_scale); grad_half_precision, loss_scale);
} }
int adam_step(int optimizer_id, int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
size_t step, float epsilon, float weight_decay, bool bias_correction,
float lr, torch::Tensor &params, torch::Tensor &grads,
float beta1, torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
float beta2, float loss_scale) {
float epsilon, auto params_c = params.contiguous();
float weight_decay, auto grads_c = grads.contiguous();
bool bias_correction, auto exp_avg_c = exp_avg.contiguous();
torch::Tensor& params, auto exp_avg_sq_c = exp_avg_sq.contiguous();
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
float loss_scale)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float* params_ptr = (float*)params_c.data_ptr(); float *params_ptr = (float *)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr(); float *grads_ptr = (float *)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt = std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]); std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2); opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction); opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
grads_ptr, params_c.numel(), (params.options().dtype() == at::kHalf),
exp_avg_ptr, (grads.options().dtype() == at::kHalf), loss_scale);
exp_avg_sq_ptr,
params_c.numel(),
(params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf),
loss_scale);
return 0; return 0;
} }
int destroy_adam_optimizer(int optimizer_id) { int destroy_adam_optimizer(int optimizer_id) {