diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
index 023e653d3..74d2f1b17 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -48,10 +48,10 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
-#define SIMD_LOAD_HALF(x)                                                      \
+#define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d)                                                  \
-  _mm256_store_ps(                                                             \
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_store_ps(            \
       x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
@@ -66,8 +66,8 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d)                                                  \
-  _mm_store_ps(                                                                \
+#define SIMD_STORE_HALF(x, d) \
+  _mm_store_ps(               \
       x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
@@ -83,19 +83,25 @@ union AVX_Data {
 
 #endif
 
-#define STEP(SPAN)                                                             \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg,              \
-                   float *_exp_avg_sq, size_t _param_size,                     \
-                   bool param_half_precision = false,                          \
+#define STEP(SPAN)                                                \
+  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
+                   float *_exp_avg_sq, size_t _param_size,        \
+                   bool param_half_precision = false,             \
                    bool grad_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
-public:
+ public:
   Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                  float eps = 1e-8, float weight_decay = 0,
                  bool adamw_mode = true)
-      : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
-        _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
+      : _alpha(alpha),
+        _betta1(betta1),
+        _betta2(betta2),
+        _eps(eps),
+        _weight_decay(weight_decay),
+        _betta1_t(1.0),
+        _betta2_t(1.0),
+        _step(0),
         _adamw_mode(adamw_mode) {}
   ~Adam_Optimizer() {}
 
@@ -135,7 +141,7 @@ public:
     }
   }
 
-private:
+ private:
   float _alpha;
   float _betta1;
   float _betta2;
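Note (not part of the patch): the hunks above only re-wrap whitespace around the half-precision SIMD macros and the Adam_Optimizer declaration; behavior is unchanged. For readers unfamiliar with the F16C intrinsics these macros wrap, the following is a minimal, self-contained sketch of what the AVX2 variants of SIMD_LOAD_HALF / SIMD_STORE_HALF do: load 8 packed fp16 values, widen them to fp32 for the update math, and narrow them back to fp16 on store. The buffer name and the stand-in multiply are illustrative only; this is not the repository's Step_* implementation. Builds with e.g. g++ -O2 -mavx2 -mf16c.

// half_round_trip.cpp -- illustrative sketch only; file name and the
// scaling step are made up, only the conversion intrinsics match the macros.
#include <immintrin.h>

#include <cstdint>
#include <cstdio>

int main() {
  // 8 fp16 values stored as raw 16-bit words (16 bytes, 16-byte aligned).
  alignas(16) uint16_t half_buf[8];

  // Seed the buffer: convert 0..7 from fp32 to fp16 and store the packed halves.
  __m256 init = _mm256_setr_ps(0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f);
  _mm_store_si128((__m128i *)half_buf,
                  _mm256_cvtps_ph(init, _MM_FROUND_TO_NEAREST_INT));

  // Equivalent of SIMD_LOAD_HALF(half_buf): 8 x fp16 -> 8 x fp32.
  __m256 v = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)half_buf));

  // Stand-in for the fp32 update math done between load and store.
  v = _mm256_mul_ps(v, _mm256_set1_ps(0.5f));

  // Equivalent of SIMD_STORE_HALF(half_buf, v): 8 x fp32 -> 8 x fp16.
  // The cast through float* mirrors the macro, which writes the packed
  // halves via _mm_store_ps on the destination pointer.
  _mm_store_ps((float *)half_buf,
               _mm_castsi128_ps(_mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT)));

  // Widen again and print to check the round trip (expect 0 0.5 1 ... 3.5).
  alignas(32) float out[8];
  _mm256_store_ps(out,
                  _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)half_buf)));
  for (int i = 0; i < 8; ++i) printf("%g ", out[i]);
  printf("\n");
  return 0;
}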