@ -35,23 +35,19 @@ SOFTWARE
void Adam_Optimizer : : Step_1 ( float * _params , float * grads , float * _exp_avg ,
float * _exp_avg_sq , size_t _param_size ,
bool param_half_precision , bool grad_half_precision ,
float loss_scale ) {
size_t rounded_size = 0 ;
bool momentum_half_precision ,
bool variance_half_precision , float loss_scale ) {
size_t rounded_size = ROUND_DOWN ( _param_size , SIMD_WIDTH ) ;
float betta1_minus1 = 1 - _betta1 ;
float betta2_minus1 = 1 - _betta2 ;
float step_size = - 1 * _alpha / _bias_correction1 ;
float w_decay = - 1 * _alpha * _weight_decay ;
__half * params_cast_h = NULL ;
__half * grads_cast_h = NULL ;
if ( param_half_precision ) {
params_cast_h = reinterpret_cast < __half * > ( _params ) ;
}
if ( grad_half_precision ) {
grads_cast_h = reinterpret_cast < __half * > ( grads ) ;
}
__half * params_cast_h = reinterpret_cast < __half * > ( _params ) ;
__half * grads_cast_h = reinterpret_cast < __half * > ( grads ) ;
__half * momentum_cast_h = reinterpret_cast < __half * > ( _exp_avg ) ;
__half * variance_cast_h = reinterpret_cast < __half * > ( _exp_avg_sq ) ;
# if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4 ;
@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
if ( _weight_decay > 0 )
weight_decay_4 . data =
( _adamw_mode ? SIMD_SET ( w_decay ) : SIMD_SET ( _weight_decay ) ) ;
rounded_size = ROUND_DOWN ( _param_size , SIMD_WIDTH ) ;
for ( size_t t = 0 ; t < rounded_size ; t + = TILE ) {
size_t copy_size = TILE ;
@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
# pragma omp parallel for
for ( size_t i = t ; i < offset ; i + = SIMD_WIDTH ) {
AVX_Data grad_4 ;
if ( grad_half_precision ) {
grad_4 . data = SIMD_LOAD_HALF ( grads_cast_h + i ) ;
} else {
grad_4 . data = SIMD_LOAD ( grads + i ) ;
}
this - > simd_load ( grad_half_precision , grads + i , grads_cast_h + i , grad_4 ) ;
if ( loss_scale > 0 ) {
AVX_Data loss_scale_vec ;
loss_scale_vec . data = SIMD_SET ( loss_scale ) ;
grad_4 . data = SIMD_DIV ( grad_4 . data , loss_scale_vec . data ) ;
}
AVX_Data momentum_4 ;
momentum_4 . data = SIMD_LOAD ( _exp_avg + i ) ;
this - > simd_load ( momentum_half_precision , _exp_avg + i ,
momentum_cast_h + i , momentum_4 ) ;
AVX_Data variance_4 ;
variance_4 . data = SIMD_LOAD ( _exp_avg_sq + i ) ;
this - > simd_load ( variance_half_precision , _exp_avg_sq + i ,
variance_cast_h + i , variance_4 ) ;
AVX_Data param_4 ;
if ( param_half_precision ) {
param_4 . data = SIMD_LOAD_HALF ( params_cast_h + i ) ;
} else {
param_4 . data = SIMD_LOAD ( _params + i ) ;
}
this - > simd_load ( param_half_precision , _params + i , params_cast_h + i ,
param_4 ) ;
if ( _weight_decay > 0 & & ! _adamw_mode ) {
grad_4 . data = SIMD_FMA ( param_4 . data , weight_decay_4 . data , grad_4 . data ) ;
@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
}
param_4 . data = SIMD_FMA ( grad_4 . data , step_size_4 . data , param_4 . data ) ;
if ( param_half_precision ) {
SIMD_STORE_HALF ( ( float * ) ( params_cast_h + i ) , param_4 . data ) ;
} else {
SIMD_STORE ( _params + i , param_4 . data ) ;
}
SIMD_STORE ( _exp_avg + i , momentum_4 . data ) ;
SIMD_STORE ( _exp_avg_sq + i , variance_4 . data ) ;
this - > simd_store ( param_half_precision , _params + i , params_cast_h + i ,
param_4 ) ;
this - > simd_store ( momentum_half_precision , _exp_avg + i ,
momentum_cast_h + i , momentum_4 ) ;
this - > simd_store ( variance_half_precision , _exp_avg_sq + i ,
variance_cast_h + i , variance_4 ) ;
}
}
# endif
@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
}
float param =
param_half_precision ? ( float ) params_cast_h [ k ] : _params [ k ] ;
float momentum = _exp_avg [ k ] ;
float variance = _exp_avg_sq [ k ] ;
float momentum =
momentum_half_precision ? ( float ) momentum_cast_h [ k ] : _exp_avg [ k ] ;
float variance = variance_half_precision ? ( float ) variance_cast_h [ k ]
: _exp_avg_sq [ k ] ;
if ( _weight_decay > 0 & & ! _adamw_mode ) {
grad = param * _weight_decay + grad ;
}
@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
params_cast_h [ k ] = ( __half ) param ;
else
_params [ k ] = param ;
_exp_avg [ k ] = momentum ;
_exp_avg_sq [ k ] = variance ;
if ( momentum_half_precision )
momentum_cast_h [ k ] = ( __half ) ( momentum ) ;
else
_exp_avg [ k ] = momentum ;
if ( variance_half_precision )
variance_cast_h [ k ] = ( __half ) ( variance ) ;
else
_exp_avg_sq [ k ] = variance ;
}
}
}
@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
void Adam_Optimizer : : Step_4 ( float * _params , float * grads , float * _exp_avg ,
float * _exp_avg_sq , size_t _param_size ,
bool param_half_precision , bool grad_half_precision ,
float loss_scale ) {
size_t rounded_size = 0 ;
bool momentum_half_precision ,
bool variance_half_precision , float loss_scale ) {
size_t rounded_size = ROUND_DOWN ( _param_size , SIMD_WIDTH * 4 ) ;
__half * params_cast_h = NULL ;
__half * grads_cast_h = NULL ;
if ( param_half_precision ) {
params_cast_h = reinterpret_cast < __half * > ( _params ) ;
}
if ( grad_half_precision ) {
grads_cast_h = reinterpret_cast < __half * > ( grads ) ;
}
__half * params_cast_h = reinterpret_cast < __half * > ( _params ) ;
__half * grads_cast_h = reinterpret_cast < __half * > ( grads ) ;
__half * momentum_cast_h = reinterpret_cast < __half * > ( _exp_avg ) ;
__half * variance_cast_h = reinterpret_cast < __half * > ( _exp_avg_sq ) ;
# if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4 ;
@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
if ( _weight_decay > 0 )
weight_decay_4 . data =
( _adamw_mode ? SIMD_SET ( w_decay ) : SIMD_SET ( _weight_decay ) ) ;
rounded_size = ROUND_DOWN ( _param_size , SIMD_WIDTH * 4 ) ;
for ( size_t t = 0 ; t < rounded_size ; t + = TILE ) {
size_t copy_size = TILE ;
@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
AVX_Data param_4 [ 4 ] ;
# pragma unroll 4
for ( int j = 0 ; j < 4 ; j + + ) {
if ( grad_half_precision ) {
grad_4 [ j ] . data = SIMD_LOAD_HALF ( grads_cast_h + i + SIMD_WIDTH * j ) ;
} else {
grad_4 [ j ] . data = SIMD_LOAD ( grads + i + SIMD_WIDTH * j ) ;
}
this - > simd_load ( grad_half_precision , grads + i + SIMD_WIDTH * j ,
grads_cast_h + i + SIMD_WIDTH * j , grad_4 [ j ] ) ;
if ( loss_scale > 0 ) {
AVX_Data loss_scale_vec ;
loss_scale_vec . data = SIMD_SET ( loss_scale ) ;
grad_4 [ j ] . data = SIMD_DIV ( grad_4 [ j ] . data , loss_scale_vec . data ) ;
}
momentum_4 [ j ] . data = SIMD_LOAD ( _exp_avg + i + SIMD_WIDTH * j ) ;
variance_4 [ j ] . data = SIMD_LOAD ( _exp_avg_sq + i + SIMD_WIDTH * j ) ;
if ( param_half_precision ) {
param_4 [ j ] . data = SIMD_LOAD_HALF ( params_cast_h + i + SIMD_WIDTH * j ) ;
} else {
param_4 [ j ] . data = SIMD_LOAD ( _params + i + SIMD_WIDTH * j ) ;
}
this - > simd_load ( momentum_half_precision , _exp_avg + i + SIMD_WIDTH * j ,
momentum_cast_h + i + SIMD_WIDTH * j , momentum_4 [ j ] ) ;
this - > simd_load ( variance_half_precision ,
_exp_avg_sq + i + SIMD_WIDTH * j ,
variance_cast_h + i + SIMD_WIDTH * j , variance_4 [ j ] ) ;
this - > simd_load ( param_half_precision , _params + i + SIMD_WIDTH * j ,
params_cast_h + i + SIMD_WIDTH * j , param_4 [ j ] ) ;
if ( _weight_decay > 0 & & ! _adamw_mode ) {
grad_4 [ j ] . data =
@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
}
param_4 [ j ] . data =
SIMD_FMA ( grad_4 [ j ] . data , step_size_4 . data , param_4 [ j ] . data ) ;
if ( param_half_precision ) {
SIMD_STORE_HALF ( ( float * ) ( params_cast_h + i + SIMD_WIDTH * j ) ,
param_4 [ j ] . data ) ;
} else {
SIMD_STORE ( _params + i + SIMD_WIDTH * j , param_4 [ j ] . data ) ;
}
SIMD_STORE ( _exp_avg + i + SIMD_WIDTH * j , momentum_4 [ j ] . data ) ;
SIMD_STORE ( _exp_avg_sq + i + SIMD_WIDTH * j , variance_4 [ j ] . data ) ;
this - > simd_store ( param_half_precision , _params + i + SIMD_WIDTH * j ,
params_cast_h + i + SIMD_WIDTH * j , param_4 [ j ] ) ;
this - > simd_store ( momentum_half_precision , _exp_avg + i + SIMD_WIDTH * j ,
momentum_cast_h + i + SIMD_WIDTH * j , momentum_4 [ j ] ) ;
this - > simd_store ( variance_half_precision ,
_exp_avg_sq + i + SIMD_WIDTH * j ,
variance_cast_h + i + SIMD_WIDTH * j , variance_4 [ j ] ) ;
}
}
}
@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
: _params + rounded_size ) ,
( grad_half_precision ? ( float * ) ( grads_cast_h + rounded_size )
: grads + rounded_size ) ,
( _exp_avg + rounded_size ) , ( _exp_avg_sq + rounded_size ) ,
( momentum_half_precision ? ( float * ) ( momentum_cast_h + rounded_size )
: _exp_avg + rounded_size ) ,
( variance_half_precision ? ( float * ) ( variance_cast_h + rounded_size )
: _exp_avg_sq + rounded_size ) ,
( _param_size - rounded_size ) , param_half_precision ,
grad_half_precision , loss_scale ) ;
grad_half_precision , momentum_half_precision ,
variance_half_precision , loss_scale ) ;
}
void Adam_Optimizer : : Step_8 ( float * _params , float * grads , float * _exp_avg ,
float * _exp_avg_sq , size_t _param_size ,
bool param_half_precision , bool grad_half_precision ,
float loss_scale ) {
size_t rounded_size = 0 ;
__half * params_cast_h = NULL ;
__half * grads_cast_h = NULL ;
if ( param_half_precision ) {
params_cast_h = reinterpret_cast < __half * > ( _params ) ;
}
if ( grad_half_precision ) {
grads_cast_h = reinterpret_cast < __half * > ( grads ) ;
}
bool momentum_half_precision ,
bool variance_half_precision , float loss_scale ) {
size_t rounded_size = ROUND_DOWN ( _param_size , SIMD_WIDTH * 8 ) ;
__half * params_cast_h = reinterpret_cast < __half * > ( _params ) ;
__half * grads_cast_h = reinterpret_cast < __half * > ( grads ) ;
__half * momentum_cast_h = reinterpret_cast < __half * > ( _exp_avg ) ;
__half * variance_cast_h = reinterpret_cast < __half * > ( _exp_avg_sq ) ;
# if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4 ;
betta1_4 . data = SIMD_SET ( _betta1 ) ;
@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
if ( _weight_decay > 0 )
weight_decay_4 . data =
( _adamw_mode ? SIMD_SET ( w_decay ) : SIMD_SET ( _weight_decay ) ) ;
rounded_size = ROUND_DOWN ( _param_size , SIMD_WIDTH * 8 ) ;
for ( size_t t = 0 ; t < rounded_size ; t + = TILE ) {
size_t copy_size = TILE ;
@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
AVX_Data param_4 [ 8 ] ;
# pragma unroll 8
for ( int j = 0 ; j < 8 ; j + + ) {
if ( grad_half_precision ) {
grad_4 [ j ] . data = SIMD_LOAD_HALF ( grads_cast_h + i + SIMD_WIDTH * j ) ;
} else {
grad_4 [ j ] . data = SIMD_LOAD ( grads + i + SIMD_WIDTH * j ) ;
}
this - > simd_load ( grad_half_precision , grads + i + SIMD_WIDTH * j ,
grads_cast_h + i + SIMD_WIDTH * j , grad_4 [ j ] ) ;
if ( loss_scale > 0 ) {
AVX_Data loss_scale_vec ;
loss_scale_vec . data = SIMD_SET ( loss_scale ) ;
grad_4 [ j ] . data = SIMD_DIV ( grad_4 [ j ] . data , loss_scale_vec . data ) ;
}
momentum_4 [ j ] . data = SIMD_LOAD ( _exp_avg + i + SIMD_WIDTH * j ) ;
variance_4 [ j ] . data = SIMD_LOAD ( _exp_avg_sq + i + SIMD_WIDTH * j ) ;
if ( param_half_precision ) {
param_4 [ j ] . data = SIMD_LOAD_HALF ( params_cast_h + i + SIMD_WIDTH * j ) ;
} else {
param_4 [ j ] . data = SIMD_LOAD ( _params + i + SIMD_WIDTH * j ) ;
}
this - > simd_load ( momentum_half_precision , _exp_avg + i + SIMD_WIDTH * j ,
momentum_cast_h + i + SIMD_WIDTH * j , momentum_4 [ j ] ) ;
this - > simd_load ( variance_half_precision ,
_exp_avg_sq + i + SIMD_WIDTH * j ,
variance_cast_h + i + SIMD_WIDTH * j , variance_4 [ j ] ) ;
this - > simd_load ( param_half_precision , _params + i + SIMD_WIDTH * j ,
params_cast_h + i + SIMD_WIDTH * j , param_4 [ j ] ) ;
if ( _weight_decay > 0 & & ! _adamw_mode ) {
grad_4 [ j ] . data =
@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
param_4 [ j ] . data =
SIMD_FMA ( grad_4 [ j ] . data , step_size_4 . data , param_4 [ j ] . data ) ;
if ( param_half_precision ) {
SIMD_STORE_HALF ( ( float * ) ( params_cast_h + i + SIMD_WIDTH * j ) ,
param_4 [ j ] . data ) ;
} else {
SIMD_STORE ( _params + i + SIMD_WIDTH * j , param_4 [ j ] . data ) ;
}
SIMD_STORE ( _exp_avg + i + ( SIMD_WIDTH * j ) , momentum_4 [ j ] . data ) ;
SIMD_STORE ( _exp_avg_sq + i + ( SIMD_WIDTH * j ) , variance_4 [ j ] . data ) ;
this - > simd_store ( param_half_precision , _params + i + SIMD_WIDTH * j ,
params_cast_h + i + SIMD_WIDTH * j , param_4 [ j ] ) ;
this - > simd_store ( momentum_half_precision , _exp_avg + i + SIMD_WIDTH * j ,
momentum_cast_h + i + SIMD_WIDTH * j , momentum_4 [ j ] ) ;
this - > simd_store ( variance_half_precision ,
_exp_avg_sq + i + SIMD_WIDTH * j ,
variance_cast_h + i + SIMD_WIDTH * j , variance_4 [ j ] ) ;
}
}
}
@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
: _params + rounded_size ) ,
( grad_half_precision ? ( float * ) ( grads_cast_h + rounded_size )
: grads + rounded_size ) ,
( _exp_avg + rounded_size ) , ( _exp_avg_sq + rounded_size ) ,
( momentum_half_precision ? ( float * ) ( momentum_cast_h + rounded_size )
: _exp_avg + rounded_size ) ,
( variance_half_precision ? ( float * ) ( variance_cast_h + rounded_size )
: _exp_avg_sq + rounded_size ) ,
( _param_size - rounded_size ) , param_half_precision ,
grad_half_precision , loss_scale ) ;
grad_half_precision , momentum_half_precision ,
variance_half_precision , loss_scale ) ;
}
void Adam_Optimizer : : step ( size_t step , float lr , float beta1 , float beta2 ,
@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
this - > update_state ( lr , epsilon , weight_decay , bias_correction ) ;
this - > Step_8 ( params_ptr , grads_ptr , exp_avg_ptr , exp_avg_sq_ptr ,
params_c . numel ( ) , ( params . options ( ) . dtype ( ) = = at : : kHalf ) ,
( grads . options ( ) . dtype ( ) = = at : : kHalf ) , loss_scale ) ;
( grads . options ( ) . dtype ( ) = = at : : kHalf ) ,
( exp_avg . options ( ) . dtype ( ) = = at : : kHalf ) ,
( exp_avg_sq . options ( ) . dtype ( ) = = at : : kHalf ) , loss_scale ) ;
}
namespace py = pybind11 ;