diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
index feb612f9f..22bec7e27 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -20,12 +20,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE
 */
 #include "cpu_adam.h"
-#include <iostream>
+
 #include <math.h>
-#include <memory>
 #include <omp.h>
 #include <string.h>
 #include <torch/extension.h>
+
+#include <iostream>
+#include <memory>
 #include <type_traits>
 #include <unordered_map>
 
@@ -82,8 +84,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -145,8 +146,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_param_size > rounded_size) {
     for (size_t t = rounded_size; t < _param_size; t += TILE) {
       size_t copy_size = TILE;
-      if ((t + TILE) > _param_size)
-        copy_size = _param_size - t;
+      if ((t + TILE) > _param_size) copy_size = _param_size - t;
       size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -235,8 +235,7 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -321,7 +320,6 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
   s_optimizers[optimizer_id] = opt;
 
   if (should_log) {
-
     std::string avx_type = "";
 #if defined(__AVX512__)
     avx_type = "AVX512";
@@ -386,8 +384,7 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -463,43 +460,29 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
          grad_half_precision, loss_scale);
 }
 
-int adam_step(int optimizer_id,
-              size_t step,
-              float lr,
-              float beta1,
-              float beta2,
-              float epsilon,
-              float weight_decay,
-              bool bias_correction,
-              torch::Tensor& params,
-              torch::Tensor& grads,
-              torch::Tensor& exp_avg,
-              torch::Tensor& exp_avg_sq,
-              float loss_scale)
-{
-    auto params_c = params.contiguous();
-    auto grads_c = grads.contiguous();
-    auto exp_avg_c = exp_avg.contiguous();
-    auto exp_avg_sq_c = exp_avg_sq.contiguous();
+int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
+              float epsilon, float weight_decay, bool bias_correction,
+              torch::Tensor &params, torch::Tensor &grads,
+              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
+              float loss_scale) {
+  auto params_c = params.contiguous();
+  auto grads_c = grads.contiguous();
+  auto exp_avg_c = exp_avg.contiguous();
+  auto exp_avg_sq_c = exp_avg_sq.contiguous();
 
-    float* params_ptr = (float*)params_c.data_ptr();
-    float* grads_ptr = (float*)grads_c.data_ptr();
-    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
-    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
-    std::shared_ptr<Adam_Optimizer> opt =
-        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-    opt->IncrementStep(step, beta1, beta2);
-    opt->update_state(lr, epsilon, weight_decay, bias_correction);
-    opt->Step_8(params_ptr,
-                grads_ptr,
-                exp_avg_ptr,
-                exp_avg_sq_ptr,
-                params_c.numel(),
-                (params.options().dtype() == at::kHalf),
-                (grads.options().dtype() == at::kHalf),
-                loss_scale);
+  float *params_ptr = (float *)params_c.data_ptr();
+  float *grads_ptr = (float *)grads_c.data_ptr();
+  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
+  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
+  std::shared_ptr<Adam_Optimizer> opt =
+      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+  opt->IncrementStep(step, beta1, beta2);
+  opt->update_state(lr, epsilon, weight_decay, bias_correction);
+  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
+              params_c.numel(), (params.options().dtype() == at::kHalf),
+              (grads.options().dtype() == at::kHalf), loss_scale);
 
-    return 0;
+  return 0;
 }
 
 int destroy_adam_optimizer(int optimizer_id) {