
[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test
Hongxin Liu authored 1 year ago, committed via GitHub
commit 4f68b3f10c
Changed files:

1. colossalai/kernel/cuda_native/csrc/cpu_adam.cpp (201 lines changed)
2. colossalai/kernel/cuda_native/csrc/cpu_adam.h (41 lines changed)
3. colossalai/nn/optimizer/cpu_adam.py (3 lines changed)
4. colossalai/nn/optimizer/hybrid_adam.py (3 lines changed)
5. tests/test_optimizer/test_adam_kernel.py (7 lines changed)
6. tests/test_optimizer/test_adam_optim.py (2 lines changed)
7. tests/test_zero/test_gemini/test_grad_clip.py (12 lines changed)
8. tests/test_zero/test_gemini/test_optim.py (15 lines changed)
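
Here "pure fp16" means that the parameters, the gradients, and both Adam states (exp_avg and exp_avg_sq) may all be torch.float16 on the CPU; before this change the fused kernel only accepted fp32 optimizer states. A minimal sketch of the path this enables, assuming a ColossalAI build with the cpu_adam extension compiled and the CPUAdam import path used elsewhere in this repository (an illustration, not part of the commit):

    import torch
    from colossalai.nn.optimizer import CPUAdam  # import path assumed from this repo's layout

    # A pure-fp16 parameter on the CPU with an fp16 gradient.
    p = torch.nn.Parameter(torch.randn(1024, dtype=torch.float16))
    p.grad = torch.randn_like(p)

    opt = CPUAdam([p], lr=1e-3)
    # The Adam states are expected to follow the parameter dtype (fp16 here);
    # with this commit the fused CPU kernel handles that case instead of
    # falling back to the slow pure-PyTorch update.
    opt.step()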

colossalai/kernel/cuda_native/csrc/cpu_adam.cpp (201 lines changed)

@@ -35,23 +35,19 @@ SOFTWARE
 void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
   float betta1_minus1 = 1 - _betta1;
   float betta2_minus1 = 1 - _betta2;
   float step_size = -1 * _alpha / _bias_correction1;
   float w_decay = -1 * _alpha * _weight_decay;
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 #pragma omp parallel for
     for (size_t i = t; i < offset; i += SIMD_WIDTH) {
       AVX_Data grad_4;
-      if (grad_half_precision) {
-        grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
-      } else {
-        grad_4.data = SIMD_LOAD(grads + i);
-      }
+      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
       if (loss_scale > 0) {
         AVX_Data loss_scale_vec;
         loss_scale_vec.data = SIMD_SET(loss_scale);
         grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
       }
       AVX_Data momentum_4;
-      momentum_4.data = SIMD_LOAD(_exp_avg + i);
+      this->simd_load(momentum_half_precision, _exp_avg + i,
+                      momentum_cast_h + i, momentum_4);
       AVX_Data variance_4;
-      variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
+      this->simd_load(variance_half_precision, _exp_avg_sq + i,
+                      variance_cast_h + i, variance_4);
       AVX_Data param_4;
-      if (param_half_precision) {
-        param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
-      } else {
-        param_4.data = SIMD_LOAD(_params + i);
-      }
+      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
+                      param_4);
 
       if (_weight_decay > 0 && !_adamw_mode) {
         grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
@@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
       param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
-      if (param_half_precision) {
-        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
-      } else {
-        SIMD_STORE(_params + i, param_4.data);
-      }
-      SIMD_STORE(_exp_avg + i, momentum_4.data);
-      SIMD_STORE(_exp_avg_sq + i, variance_4.data);
+      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
+                       param_4);
+      this->simd_store(momentum_half_precision, _exp_avg + i,
+                       momentum_cast_h + i, momentum_4);
+      this->simd_store(variance_half_precision, _exp_avg_sq + i,
+                       variance_cast_h + i, variance_4);
     }
   }
 #endif
@@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
       float param =
           param_half_precision ? (float)params_cast_h[k] : _params[k];
-      float momentum = _exp_avg[k];
-      float variance = _exp_avg_sq[k];
+      float momentum =
+          momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
+      float variance = variance_half_precision ? (float)variance_cast_h[k]
+                                               : _exp_avg_sq[k];
       if (_weight_decay > 0 && !_adamw_mode) {
         grad = param * _weight_decay + grad;
       }
@@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
         params_cast_h[k] = (__half)param;
       else
         _params[k] = param;
-      _exp_avg[k] = momentum;
-      _exp_avg_sq[k] = variance;
+      if (momentum_half_precision)
+        momentum_cast_h[k] = (__half)(momentum);
+      else
+        _exp_avg[k] = momentum;
+      if (variance_half_precision)
+        variance_cast_h[k] = (__half)(variance);
+      else
+        _exp_avg_sq[k] = variance;
     }
   }
 }
@@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
-
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[4];
 #pragma unroll 4
       for (int j = 0; j < 4; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
         if (loss_scale > 0) {
           AVX_Data loss_scale_vec;
           loss_scale_vec.data = SIMD_SET(loss_scale);
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
-
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
         }
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-        SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
            (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                 : grads + rounded_size),
-           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                    : _exp_avg + rounded_size),
+           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                    : _exp_avg_sq + rounded_size),
            (_param_size - rounded_size), param_half_precision,
-           grad_half_precision, loss_scale);
+           grad_half_precision, momentum_half_precision,
+           variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
-
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
   betta1_4.data = SIMD_SET(_betta1);
@@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[8];
 #pragma unroll 8
       for (int j = 0; j < 8; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
         if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
-
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-
-        SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
            (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                 : grads + rounded_size),
-           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                    : _exp_avg + rounded_size),
+           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                    : _exp_avg_sq + rounded_size),
            (_param_size - rounded_size), param_half_precision,
-           grad_half_precision, loss_scale);
+           grad_half_precision, momentum_half_precision,
+           variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
@@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
   this->update_state(lr, epsilon, weight_decay, bias_correction);
   this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
                params_c.numel(), (params.options().dtype() == at::kHalf),
-               (grads.options().dtype() == at::kHalf), loss_scale);
+               (grads.options().dtype() == at::kHalf),
+               (exp_avg.options().dtype() == at::kHalf),
+               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
 }
 
 namespace py = pybind11;

colossalai/kernel/cuda_native/csrc/cpu_adam.h (41 lines changed)

@@ -50,9 +50,9 @@ SOFTWARE
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
-      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d)                                         \
+  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
+                                     d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8
@@ -66,9 +66,9 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
-      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d)                                   \
+  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
+                                  d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
@@ -83,11 +83,12 @@ union AVX_Data {
 #endif
 
-#define STEP(SPAN) \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size, \
-                   bool param_half_precision = false, \
-                   bool grad_half_precision = false, float loss_scale = -1);
+#define STEP(SPAN)                                                             \
+  void Step_##SPAN(                                                            \
+      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,       \
+      size_t _param_size, bool param_half_precision = false,                   \
+      bool grad_half_precision = false, bool momentum_half_precision = false,  \
+      bool variance_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
  public:
@@ -141,6 +142,24 @@ class Adam_Optimizer {
     }
   }
 
+  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
+                        AVX_Data &data) {
+    if (is_half) {
+      data.data = SIMD_LOAD_HALF(h_ptr);
+    } else {
+      data.data = SIMD_LOAD(ptr);
+    }
+  }
+
+  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
+                         AVX_Data &data) {
+    if (is_half) {
+      SIMD_STORE_HALF(h_ptr, data.data);
+    } else {
+      SIMD_STORE(ptr, data.data);
+    }
+  }
+
   void step(size_t step, float lr, float beta1, float beta2, float epsilon,
             float weight_decay, bool bias_correction, torch::Tensor &params,
             torch::Tensor &grads, torch::Tensor &exp_avg,

colossalai/nn/optimizer/cpu_adam.py (3 lines changed)

@@ -146,8 +146,7 @@ class CPUAdam(NVMeOptimizer):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]

colossalai/nn/optimizer/hybrid_adam.py (3 lines changed)

@@ -122,8 +122,7 @@ class HybridAdam(CPUAdam):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
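
The two optimizer hunks above replace the old fallback condition (p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float) with a bf16-only check. A tiny illustrative sketch of the resulting dispatch rule (simplified, hypothetical helper name; not the optimizers' actual code):

    import torch

    def uses_fused_cpu_kernel(p: torch.Tensor) -> bool:
        # Before: any non-fp32 parameter fell back to the slow pure-PyTorch update.
        # After: only bf16 gradients fall back; fp32 and fp16 parameters, gradients,
        # and optimizer states are handled by the fused CPU kernel.
        return p.grad is not None and p.grad.dtype is not torch.bfloat16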

tests/test_optimizer/test_adam_kernel.py (7 lines changed)

@@ -13,9 +13,7 @@ from colossalai.utils import get_current_device, multi_tensor_applier
 _FUSED_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
-    (torch.bfloat16, torch.float),
     (torch.float, torch.bfloat16),
     (torch.bfloat16, torch.bfloat16),
 ]
@@ -23,7 +21,6 @@ _FUSED_ALLOWED_P_G_TYPES = [
 _CPU_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
 ]
@@ -138,8 +135,8 @@ def check_adam_kernel(
     master_exp_avg_sq = torch.zeros_like(master_p)
     p = master_p.clone().to(p_dtype)
     g = master_g.clone().to(g_dtype)
-    exp_avg = master_exp_avg.clone()
-    exp_avg_sq = master_exp_avg_sq.clone()
+    exp_avg = master_exp_avg.clone().to(p_dtype)
+    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)
 
     for step in range(1, 1 + n_steps):
         torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)

tests/test_optimizer/test_adam_optim.py (2 lines changed)

@@ -21,8 +21,6 @@ _ALLOWED_P_G_TYPES = [
     (torch.float, torch.float),  # pure fp32
     (torch.float, torch.half),  # fp16 amp
     (torch.float, torch.bfloat16),  # bfloat16 amp
-    # (torch.half, torch.half),  # FIXME(ver217): cpu adam kernel does not support pure fp16
-    # (torch.bfloat16, torch.bfloat16),  # FIXME(ver217): cpu adam kernel does not support pure bfloat16
 ]
 
 N_STEPS = 3

tests/test_zero/test_gemini/test_grad_clip.py (12 lines changed)

@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2"])
-def exam_grad_clipping(placement_config, model_name: str):
+@parameterize("master_weights", [True, False])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     set_seed(1912)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
         chunk_config_dict=config_dict,
         chunk_init_device=init_device,
         pin_memory=True,
+        master_weights=master_weights,
         **placement_config,
     )
@@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str):
         torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-        assert_close(torch_loss, loss)
+
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss)
 
         import apex.amp as apex_amp
@@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str):
         torch_optim.step()
         zero_optim.step()
 
-        check_param(model, torch_model)
+        if master_weights:
+            check_param(model, torch_model)
 
 
 def run_dist(rank, world_size, port):

tests/test_zero/test_gemini/test_optim.py (15 lines changed)

@@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
+@parameterize("master_weights", [True, False])
+def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 
     torch_model = model_builder().cuda()
+    # apex no master weights leads to nan, so we don't use it
     amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
@@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
-    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
+    model = GeminiDDP(
+        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+    )
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
@@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-        assert_close(torch_loss, loss, rtol=rtol, atol=atol)
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss, rtol=rtol, atol=atol)
 
         zero_optim.step()
         torch_optim.step()
 
-        check_param(model, torch_model, mixed_precision)
+        if master_weights:
+            check_param(model, torch_model, mixed_precision)
 
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
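
The master_weights switch exercised above is the configuration that motivates the kernel change: with master_weights=False, Gemini keeps no fp32 master copy, so the working weights (and hence the CPU Adam states) stay in the mixed-precision dtype. A rough sketch of enabling it; the import paths and the single-process launch flow are assumptions based on these tests, and the script would be run under torchrun so the distributed environment variables are set:

    import torch
    import colossalai
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import GeminiDDP, GeminiOptimizer

    colossalai.launch_from_torch(config={})  # e.g. `torchrun --nproc_per_node=1 script.py`

    model = torch.nn.Linear(32, 32)
    # master_weights=False: no fp32 master copy, so weights and optimizer states
    # remain in half precision, which is what the updated kernel and tests cover.
    model = GeminiDDP(model, mixed_precision=torch.half, master_weights=False)
    optimizer = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=128)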
