diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
index 758722a8f..023e653d3 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -21,11 +21,11 @@ SOFTWARE
 */
 #pragma once
 
+#include
+#include
 #include
 #include
 #include
-#include
-#include
 
 #if (__x86_64__ || __i386__)
 #include
@@ -48,8 +48,11 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
-#define SIMD_LOAD_HALF(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x)))
-#define SIMD_STORE_HALF(x, d) _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_LOAD_HALF(x) \
+  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_store_ps( \
+      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8
@@ -62,102 +65,89 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
-#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x)))
-#define SIMD_STORE_HALF(x, d) _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm_store_ps( \
+      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
 
 union AVX_Data {
 #if defined(__AVX512__)
-    __m512 data;
+  __m512 data;
 #elif defined(__AVX256__) or defined(__AVX2__)
-    __m256 data;
+  __m256 data;
 #endif
-    // float data_f[16];
+  // float data_f[16];
 };
 
 #endif
 
-
-#define STEP(SPAN)                                      \
-    void Step_##SPAN(float* _params,                    \
-                     float* grads,                      \
-                     float* _exp_avg,                   \
-                     float* _exp_avg_sq,                \
-                     size_t _param_size,                \
-                     bool param_half_precision = false, \
-                     bool grad_half_precision = false,  \
-                     float loss_scale = -1);
+#define STEP(SPAN)                                                 \
+  void Step_##SPAN(float *_params, float *grads, float *_exp_avg,  \
+                   float *_exp_avg_sq, size_t _param_size,         \
+                   bool param_half_precision = false,              \
+                   bool grad_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
  public:
-    Adam_Optimizer(float alpha = 1e-3,
-                   float betta1 = 0.9,
-                   float betta2 = 0.999,
-                   float eps = 1e-8,
-                   float weight_decay = 0,
-                   bool adamw_mode = true)
-        : _alpha(alpha),
-          _betta1(betta1),
-          _betta2(betta2),
-          _eps(eps),
-          _weight_decay(weight_decay),
-          _betta1_t(1.0),
-          _betta2_t(1.0),
-          _step(0),
-          _adamw_mode(adamw_mode){}
-    ~Adam_Optimizer(){}
-
-    STEP(1)
-    STEP(4)
-    STEP(8)
-    inline void IncrementStep(size_t step, float beta1, float beta2)
-    {
-        if (beta1 != _betta1 || beta2 != _betta2) {
-            _step = step;
-            _betta1 = beta1;
-            _betta2 = beta2;
-            _betta1_t = std::pow(_betta1, step);
-            _betta2_t = std::pow(_betta2, step);
-        } else {
-            _step++;
-            if (_step != step) {
-                _betta1_t = std::pow(_betta1, step);
-                _betta2_t = std::pow(_betta2, step);
-                _step = step;
-            } else {
-                _betta1_t *= _betta1;
-                _betta2_t *= _betta2;
-            }
-        }
-    }
-    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
-    {
-        _alpha = lr;
-        _eps = epsilon;
-        _weight_decay = weight_decay;
+  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
+                 float eps = 1e-8, float weight_decay = 0,
+                 bool adamw_mode = true)
+      : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
+        _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
+        _adamw_mode(adamw_mode) {}
+  ~Adam_Optimizer() {}
 
-        _bias_correction1 = 1.0f;
-        _bias_correction2 = 1.0f;
-        if (bias_correction == 1) {
-            _bias_correction1 = 1 - _betta1_t;
-            _bias_correction2 = 1 / sqrt(1 - _betta2_t);
-        }
+  STEP(1)
+  STEP(4)
+  STEP(8)
+  inline void IncrementStep(size_t step, float beta1, float beta2) {
+    if (beta1 != _betta1 || beta2 != _betta2) {
+      _step = step;
+      _betta1 = beta1;
+      _betta2 = beta2;
+      _betta1_t = std::pow(_betta1, step);
+      _betta2_t = std::pow(_betta2, step);
+    } else {
+      _step++;
+      if (_step != step) {
+        _betta1_t = std::pow(_betta1, step);
+        _betta2_t = std::pow(_betta2, step);
+        _step = step;
+      } else {
+        _betta1_t *= _betta1;
+        _betta2_t *= _betta2;
+      }
     }
+  }
+  inline void update_state(float lr, float epsilon, float weight_decay,
+                           bool bias_correction) {
+    _alpha = lr;
+    _eps = epsilon;
+    _weight_decay = weight_decay;
+
+    _bias_correction1 = 1.0f;
+    _bias_correction2 = 1.0f;
+    if (bias_correction == 1) {
+      _bias_correction1 = 1 - _betta1_t;
+      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
+    }
+  }
 
  private:
-    float _alpha;
-    float _betta1;
-    float _betta2;
-    float _eps;
-    float _weight_decay;
+  float _alpha;
+  float _betta1;
+  float _betta2;
+  float _eps;
+  float _weight_decay;
 
-    float _betta1_t;
-    float _betta2_t;
-    size_t _step;
+  float _betta1_t;
+  float _betta2_t;
+  size_t _step;
 
-    float _bias_correction1;
-    float _bias_correction2;
+  float _bias_correction1;
+  float _bias_correction2;
 
-    bool _adamw_mode;
+  bool _adamw_mode;
 };
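
For context on the half-precision macros reformatted above: SIMD_STORE_HALF narrows eight fp32 lanes to packed fp16 with round-to-nearest and stores the 16-byte result, and SIMD_LOAD_HALF widens them back to fp32. Below is a minimal, self-contained sketch of that round trip using the AVX2/F16C variants exactly as defined in this header. It assumes an F16C-capable CPU and compilation with -mavx2 -mf16c; the buffer names (src, packed_half, dst) are illustrative, not from the repo.

#include <immintrin.h>

#include <cstdio>

// AVX2/F16C variants, as defined in cpu_adam.h above.
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm_store_ps( \
      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

int main() {
  float src[8] = {0.5f, 1.5f, -2.0f, 3.25f, 0.1f, -0.1f, 128.0f, -128.0f};
  // Four floats = 16 bytes = eight packed fp16 values; _mm_store_ps requires
  // 16-byte alignment, hence alignas(16).
  alignas(16) float packed_half[4];

  __m256 v = _mm256_loadu_ps(src);            // eight fp32 lanes
  SIMD_STORE_HALF(packed_half, v);            // fp32 -> fp16, round to nearest
  __m256 back = SIMD_LOAD_HALF(packed_half);  // fp16 -> fp32

  float dst[8];
  _mm256_storeu_ps(dst, back);
  for (int i = 0; i < 8; ++i)
    std::printf("%g -> %g\n", src[i], dst[i]);  // 0.1 comes back inexact (fp16)
  return 0;
}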
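The Step_1/Step_4/Step_8 kernels are only declared here via the STEP macro; their definitions live in cpu_adam.cpp. What this header does fix is the step bookkeeping: IncrementStep maintains the running powers _betta1_t = beta1^t and _betta2_t = beta2^t (incrementally when steps arrive in order, via std::pow otherwise), and update_state folds them into _bias_correction1 = 1 - beta1^t and _bias_correction2 = 1 / sqrt(1 - beta2^t), the latter stored pre-inverted so the kernels can multiply instead of divide. Here is a scalar sketch of the conventional Adam/AdamW update that these cached factors support; it illustrates the math, not the vectorized implementation, and adam_step_scalar is a made-up name.

#include <cmath>

// Hypothetical scalar reference for one parameter; the real Step_N kernels
// vectorize this over SIMD_WIDTH-lane chunks using the macros above.
void adam_step_scalar(float &param, float grad, float &exp_avg,
                      float &exp_avg_sq, float alpha, float beta1, float beta2,
                      float eps, float weight_decay, float bias_correction1,
                      float bias_correction2, bool adamw_mode,
                      float loss_scale = -1) {
  if (loss_scale > 0) grad /= loss_scale;  // undo static loss scaling, if any
  if (weight_decay > 0 && !adamw_mode)
    grad += weight_decay * param;  // classic L2-regularized Adam
  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;               // 1st moment
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;  // 2nd moment
  // bias_correction2 already holds 1/sqrt(1 - beta2^t), so multiply here.
  float denom = std::sqrt(exp_avg_sq) * bias_correction2 + eps;
  if (weight_decay > 0 && adamw_mode)
    param -= alpha * weight_decay * param;  // decoupled decay (AdamW)
  param -= (alpha / bias_correction1) * exp_avg / denom;
}

Note that when update_state is called with bias_correction disabled, both factors stay at 1.0f and the same code performs an uncorrected update.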
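Finally, a hypothetical driver loop showing the call order the header implies: steps count from 1 (consistent with std::pow(_betta1, step) in IncrementStep), and the bias corrections are refreshed before each Step_N call. This assumes cpu_adam.h plus the Step_N definitions from cpu_adam.cpp are linked in; buffer names and hyperparameter values are illustrative.

#include <cstddef>
#include <vector>

#include "cpu_adam.h"  // the header patched above; link against cpu_adam.cpp

int main() {
  const size_t n = 1024;
  std::vector<float> params(n, 0.0f), grads(n, 0.01f);
  std::vector<float> exp_avg(n, 0.0f), exp_avg_sq(n, 0.0f);

  Adam_Optimizer opt;  // defaults: alpha = 1e-3, betas = (0.9, 0.999), AdamW
  for (size_t t = 1; t <= 10; ++t) {
    opt.IncrementStep(t, 0.9f, 0.999f);          // keep beta1^t, beta2^t fresh
    opt.update_state(1e-3f, 1e-8f, 0.0f, true);  // recompute bias corrections
    // fp32 params and grads, no loss scaling (loss_scale defaults to -1).
    opt.Step_8(params.data(), grads.data(), exp_avg.data(), exp_avg_sq.data(),
               n);
  }
  return 0;
}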