mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* [npu] setup device utils (#5047) * [npu] add npu device support * [npu] support low level zero * [test] update npu zero plugin test * [hotfix] fix import * [test] recover tests * [npu] gemini support npu (#5052) * [npu] refactor device utils * [gemini] support npu * [example] llama2+gemini support npu * [kernel] add arm cpu adam kernel (#5065) * [kernel] add arm cpu adam * [optim] update adam optimizer * [kernel] arm cpu adam remove bf16 supportpull/5072/head
Hongxin Liu
1 year ago
committed by
GitHub
46 changed files with 989 additions and 228 deletions
@ -0,0 +1,304 @@ |
|||||||
|
#include "cpu_adam_arm.h" |
||||||
|
|
||||||
|
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg, |
||||||
|
void *_exp_avg_sq, size_t _param_size, |
||||||
|
at::ScalarType param_dtype, |
||||||
|
at::ScalarType grad_dtype, |
||||||
|
at::ScalarType exp_avg_dtype, |
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale) { |
||||||
|
size_t rounded_size = 0; |
||||||
|
#if defined(__aarch64__) |
||||||
|
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); |
||||||
|
#endif |
||||||
|
|
||||||
|
float betta1_minus1 = 1 - _betta1; |
||||||
|
float betta2_minus1 = 1 - _betta2; |
||||||
|
float step_size = -1 * _alpha / _bias_correction1; |
||||||
|
float w_decay = -1 * _alpha * _weight_decay; |
||||||
|
|
||||||
|
#if defined(__aarch64__) |
||||||
|
float32x4_t betta1_4 = simd_set(_betta1); |
||||||
|
float32x4_t betta2_4 = simd_set(_betta2); |
||||||
|
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); |
||||||
|
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); |
||||||
|
float32x4_t bias2_sqrt = simd_set(_bias_correction2); |
||||||
|
float32x4_t eps_4 = simd_set(_eps); |
||||||
|
float32x4_t step_size_4 = simd_set(step_size); |
||||||
|
float32x4_t weight_decay_4; |
||||||
|
if (_weight_decay > 0) { |
||||||
|
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); |
||||||
|
} |
||||||
|
for (size_t t = 0; t < rounded_size; t += TILE) { |
||||||
|
size_t copy_size = TILE; |
||||||
|
if ((t + TILE) > rounded_size) copy_size = rounded_size - t; |
||||||
|
size_t offset = copy_size + t; |
||||||
|
|
||||||
|
#pragma omp parallel for |
||||||
|
for (size_t i = t; i < offset; i += SIMD_WIDTH) { |
||||||
|
float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i); |
||||||
|
if (loss_scale > 0) { |
||||||
|
float32x4_t loss_scale_vec = simd_set(loss_scale); |
||||||
|
grad_4 = vdivq_f32(grad_4, loss_scale_vec); |
||||||
|
} |
||||||
|
float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i); |
||||||
|
float32x4_t variance_4 = |
||||||
|
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i); |
||||||
|
float32x4_t param_4 = simd_load_offset(_params, param_dtype, i); |
||||||
|
if (_weight_decay > 0 && !_adamw_mode) { |
||||||
|
grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4); |
||||||
|
} |
||||||
|
momentum_4 = vmulq_f32(momentum_4, betta1_4); |
||||||
|
momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4); |
||||||
|
variance_4 = vmulq_f32(variance_4, betta2_4); |
||||||
|
grad_4 = vmulq_f32(grad_4, grad_4); |
||||||
|
variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4); |
||||||
|
grad_4 = vsqrtq_f32(variance_4); |
||||||
|
grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt); |
||||||
|
grad_4 = vdivq_f32(momentum_4, grad_4); |
||||||
|
if (_weight_decay > 0 && _adamw_mode) { |
||||||
|
param_4 = vfmaq_f32(param_4, param_4, weight_decay_4); |
||||||
|
} |
||||||
|
param_4 = vfmaq_f32(param_4, grad_4, step_size_4); |
||||||
|
simd_store_offset(_params, param_dtype, param_4, i); |
||||||
|
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i); |
||||||
|
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i); |
||||||
|
} |
||||||
|
} |
||||||
|
#endif |
||||||
|
if (_param_size > rounded_size) { |
||||||
|
for (size_t t = rounded_size; t < _param_size; t += TILE) { |
||||||
|
size_t copy_size = TILE; |
||||||
|
if ((t + TILE) > _param_size) copy_size = _param_size - t; |
||||||
|
size_t offset = copy_size + t; |
||||||
|
|
||||||
|
#pragma omp parallel for |
||||||
|
for (size_t k = t; k < offset; k++) { |
||||||
|
float grad = scalar_load_offset(grads, grad_dtype, k); |
||||||
|
if (loss_scale > 0) { |
||||||
|
grad /= loss_scale; |
||||||
|
} |
||||||
|
float param = scalar_load_offset(_params, param_dtype, k); |
||||||
|
float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k); |
||||||
|
float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k); |
||||||
|
if (_weight_decay > 0 && !_adamw_mode) { |
||||||
|
grad = param * _weight_decay + grad; |
||||||
|
} |
||||||
|
momentum = momentum * _betta1; |
||||||
|
momentum = grad * betta1_minus1 + momentum; |
||||||
|
|
||||||
|
variance = variance * _betta2; |
||||||
|
grad = grad * grad; |
||||||
|
variance = grad * betta2_minus1 + variance; |
||||||
|
|
||||||
|
grad = sqrt(variance); |
||||||
|
grad = grad * _bias_correction2 + _eps; |
||||||
|
grad = momentum / grad; |
||||||
|
if (_weight_decay > 0 && _adamw_mode) { |
||||||
|
param += w_decay * param; |
||||||
|
} |
||||||
|
param = grad * step_size + param; |
||||||
|
|
||||||
|
scalar_store_offset(_params, param_dtype, param, k); |
||||||
|
scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k); |
||||||
|
scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg, |
||||||
|
void *_exp_avg_sq, size_t _param_size, |
||||||
|
at::ScalarType param_dtype, |
||||||
|
at::ScalarType grad_dtype, |
||||||
|
at::ScalarType exp_avg_dtype, |
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale) { |
||||||
|
size_t rounded_size = 0; |
||||||
|
#if defined(__aarch64__) |
||||||
|
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); |
||||||
|
#endif |
||||||
|
|
||||||
|
float betta1_minus1 = 1 - _betta1; |
||||||
|
float betta2_minus1 = 1 - _betta2; |
||||||
|
float step_size = -1 * _alpha / _bias_correction1; |
||||||
|
float w_decay = -1 * _alpha * _weight_decay; |
||||||
|
|
||||||
|
#if defined(__aarch64__) |
||||||
|
float32x4_t betta1_4 = simd_set(_betta1); |
||||||
|
float32x4_t betta2_4 = simd_set(_betta2); |
||||||
|
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); |
||||||
|
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); |
||||||
|
float32x4_t bias2_sqrt = simd_set(_bias_correction2); |
||||||
|
float32x4_t eps_4 = simd_set(_eps); |
||||||
|
float32x4_t step_size_4 = simd_set(step_size); |
||||||
|
float32x4_t weight_decay_4; |
||||||
|
if (_weight_decay > 0) { |
||||||
|
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); |
||||||
|
} |
||||||
|
|
||||||
|
for (size_t t = 0; t < rounded_size; t += TILE) { |
||||||
|
size_t copy_size = TILE; |
||||||
|
if ((t + TILE) > rounded_size) copy_size = rounded_size - t; |
||||||
|
size_t offset = copy_size + t; |
||||||
|
|
||||||
|
#pragma omp parallel for |
||||||
|
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { |
||||||
|
float32x4_t grad_4[4]; |
||||||
|
float32x4_t momentum_4[4]; |
||||||
|
float32x4_t variance_4[4]; |
||||||
|
float32x4_t param_4[4]; |
||||||
|
#pragma unroll 4 |
||||||
|
for (int j = 0; j < 4; j++) { |
||||||
|
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j); |
||||||
|
if (loss_scale > 0) { |
||||||
|
float32x4_t loss_scale_vec = simd_set(loss_scale); |
||||||
|
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec); |
||||||
|
} |
||||||
|
momentum_4[j] = |
||||||
|
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j); |
||||||
|
variance_4[j] = |
||||||
|
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j); |
||||||
|
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j); |
||||||
|
if (_weight_decay > 0 && !_adamw_mode) { |
||||||
|
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4); |
||||||
|
} |
||||||
|
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4); |
||||||
|
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4); |
||||||
|
variance_4[j] = vmulq_f32(variance_4[j], betta2_4); |
||||||
|
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]); |
||||||
|
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4); |
||||||
|
grad_4[j] = vsqrtq_f32(variance_4[j]); |
||||||
|
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt); |
||||||
|
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]); |
||||||
|
if (_weight_decay > 0 && _adamw_mode) { |
||||||
|
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4); |
||||||
|
} |
||||||
|
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4); |
||||||
|
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j); |
||||||
|
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j], |
||||||
|
i + SIMD_WIDTH * j); |
||||||
|
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j], |
||||||
|
i + SIMD_WIDTH * j); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
#endif |
||||||
|
if (_param_size > rounded_size) { |
||||||
|
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size), |
||||||
|
scalar_seek_offset(grads, grad_dtype, rounded_size), |
||||||
|
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size), |
||||||
|
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size), |
||||||
|
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype, |
||||||
|
exp_avg_sq_dtype, loss_scale); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg, |
||||||
|
void *_exp_avg_sq, size_t _param_size, |
||||||
|
at::ScalarType param_dtype, |
||||||
|
at::ScalarType grad_dtype, |
||||||
|
at::ScalarType exp_avg_dtype, |
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale) { |
||||||
|
size_t rounded_size = 0; |
||||||
|
#if defined(__aarch64__) |
||||||
|
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); |
||||||
|
#endif |
||||||
|
|
||||||
|
float betta1_minus1 = 1 - _betta1; |
||||||
|
float betta2_minus1 = 1 - _betta2; |
||||||
|
float step_size = -1 * _alpha / _bias_correction1; |
||||||
|
float w_decay = -1 * _alpha * _weight_decay; |
||||||
|
#if defined(__aarch64__) |
||||||
|
float32x4_t betta1_4 = simd_set(_betta1); |
||||||
|
float32x4_t betta2_4 = simd_set(_betta2); |
||||||
|
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); |
||||||
|
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); |
||||||
|
float32x4_t bias2_sqrt = simd_set(_bias_correction2); |
||||||
|
float32x4_t eps_4 = simd_set(_eps); |
||||||
|
float32x4_t step_size_4 = simd_set(step_size); |
||||||
|
float32x4_t weight_decay_4; |
||||||
|
if (_weight_decay > 0) { |
||||||
|
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); |
||||||
|
} |
||||||
|
|
||||||
|
for (size_t t = 0; t < rounded_size; t += TILE) { |
||||||
|
size_t copy_size = TILE; |
||||||
|
if ((t + TILE) > rounded_size) copy_size = rounded_size - t; |
||||||
|
size_t offset = copy_size + t; |
||||||
|
|
||||||
|
#pragma omp parallel for |
||||||
|
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { |
||||||
|
float32x4_t grad_4[8]; |
||||||
|
float32x4_t momentum_4[8]; |
||||||
|
float32x4_t variance_4[8]; |
||||||
|
float32x4_t param_4[8]; |
||||||
|
#pragma unroll 4 |
||||||
|
for (int j = 0; j < 8; j++) { |
||||||
|
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j); |
||||||
|
if (loss_scale > 0) { |
||||||
|
float32x4_t loss_scale_vec = simd_set(loss_scale); |
||||||
|
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec); |
||||||
|
} |
||||||
|
momentum_4[j] = |
||||||
|
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j); |
||||||
|
variance_4[j] = |
||||||
|
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j); |
||||||
|
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j); |
||||||
|
if (_weight_decay > 0 && !_adamw_mode) { |
||||||
|
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4); |
||||||
|
} |
||||||
|
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4); |
||||||
|
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4); |
||||||
|
variance_4[j] = vmulq_f32(variance_4[j], betta2_4); |
||||||
|
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]); |
||||||
|
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4); |
||||||
|
grad_4[j] = vsqrtq_f32(variance_4[j]); |
||||||
|
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt); |
||||||
|
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]); |
||||||
|
if (_weight_decay > 0 && _adamw_mode) { |
||||||
|
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4); |
||||||
|
} |
||||||
|
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4); |
||||||
|
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j); |
||||||
|
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j], |
||||||
|
i + SIMD_WIDTH * j); |
||||||
|
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j], |
||||||
|
i + SIMD_WIDTH * j); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
#endif |
||||||
|
if (_param_size > rounded_size) { |
||||||
|
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size), |
||||||
|
scalar_seek_offset(grads, grad_dtype, rounded_size), |
||||||
|
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size), |
||||||
|
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size), |
||||||
|
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype, |
||||||
|
exp_avg_sq_dtype, loss_scale); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2, |
||||||
|
float epsilon, float weight_decay, |
||||||
|
bool bias_correction, torch::Tensor ¶ms, |
||||||
|
torch::Tensor &grads, torch::Tensor &exp_avg, |
||||||
|
torch::Tensor &exp_avg_sq, float loss_scale) { |
||||||
|
auto params_c = params.contiguous(); |
||||||
|
auto grads_c = grads.contiguous(); |
||||||
|
auto exp_avg_c = exp_avg.contiguous(); |
||||||
|
auto exp_avg_sq_c = exp_avg_sq.contiguous(); |
||||||
|
|
||||||
|
this->IncrementStep(step, beta1, beta2); |
||||||
|
this->update_state(lr, epsilon, weight_decay, bias_correction); |
||||||
|
this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(), |
||||||
|
exp_avg_sq_c.data_ptr(), params_c.numel(), |
||||||
|
params_c.scalar_type(), grads_c.scalar_type(), |
||||||
|
exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale); |
||||||
|
} |
||||||
|
|
||||||
|
namespace py = pybind11; |
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
||||||
|
py::class_<AdamOptimizer>(m, "CPUAdamOptimizer") |
||||||
|
.def(py::init<float, float, float, float, float, bool>()) |
||||||
|
.def("step", &AdamOptimizer::step); |
||||||
|
} |
@ -0,0 +1,201 @@ |
|||||||
|
#pragma once |
||||||
|
#include <ATen/ATen.h> |
||||||
|
#include <torch/extension.h> |
||||||
|
|
||||||
|
#include <cmath> |
||||||
|
|
||||||
|
#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) |
||||||
|
#define TILE (128 * 1024 * 1024) |
||||||
|
|
||||||
|
#if defined(__aarch64__) |
||||||
|
#include <arm_neon.h> |
||||||
|
#define SIMD_WIDTH 4 |
||||||
|
|
||||||
|
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype, |
||||||
|
size_t offset) { |
||||||
|
switch (dtype) { |
||||||
|
case at::ScalarType::Float: { |
||||||
|
auto ptr_f = reinterpret_cast<const float32_t *>(ptr); |
||||||
|
return vld1q_f32(ptr_f + offset); |
||||||
|
} |
||||||
|
case at::ScalarType::Half: { |
||||||
|
auto ptr_h = reinterpret_cast<const float16_t *>(ptr); |
||||||
|
return vcvt_f32_f16(vld1_f16(ptr_h + offset)); |
||||||
|
} |
||||||
|
// case at::ScalarType::BFloat16: {
|
||||||
|
// auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
|
||||||
|
// return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
|
||||||
|
// }
|
||||||
|
default: |
||||||
|
AT_ERROR("Unsupported dtype"); |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) { |
||||||
|
return simd_load_offset(ptr, dtype, 0); |
||||||
|
} |
||||||
|
|
||||||
|
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data, |
||||||
|
size_t offset) { |
||||||
|
switch (dtype) { |
||||||
|
case at::ScalarType::Float: { |
||||||
|
auto ptr_f = reinterpret_cast<float32_t *>(ptr); |
||||||
|
vst1q_f32(ptr_f + offset, data); |
||||||
|
break; |
||||||
|
} |
||||||
|
case at::ScalarType::Half: { |
||||||
|
auto ptr_h = reinterpret_cast<float16_t *>(ptr); |
||||||
|
vst1_f16(ptr_h + offset, vcvt_f16_f32(data)); |
||||||
|
break; |
||||||
|
} |
||||||
|
// case at::ScalarType::BFloat16: {
|
||||||
|
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
|
||||||
|
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
default: |
||||||
|
AT_ERROR("Unsupported dtype"); |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) { |
||||||
|
return simd_store_offset(ptr, dtype, data, 0); |
||||||
|
} |
||||||
|
|
||||||
|
inline float32x4_t simd_set(float value) { |
||||||
|
auto val = static_cast<float32_t>(value); |
||||||
|
return vdupq_n_f32(val); |
||||||
|
} |
||||||
|
|
||||||
|
#endif |
||||||
|
|
||||||
|
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype, |
||||||
|
size_t offset) { |
||||||
|
switch (dtype) { |
||||||
|
case at::ScalarType::Float: |
||||||
|
return *(reinterpret_cast<const float *>(ptr) + offset); |
||||||
|
case at::ScalarType::Half: |
||||||
|
return static_cast<float>( |
||||||
|
*(reinterpret_cast<const at::Half *>(ptr) + offset)); |
||||||
|
// case at::ScalarType::BFloat16:
|
||||||
|
// return static_cast<float>(
|
||||||
|
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
|
||||||
|
default: |
||||||
|
AT_ERROR("Unsupported dtype"); |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data, |
||||||
|
size_t offset) { |
||||||
|
switch (dtype) { |
||||||
|
case at::ScalarType::Float: |
||||||
|
*(reinterpret_cast<float *>(ptr) + offset) = data; |
||||||
|
break; |
||||||
|
case at::ScalarType::Half: |
||||||
|
*(reinterpret_cast<at::Half *>(ptr) + offset) = data; |
||||||
|
break; |
||||||
|
// case at::ScalarType::BFloat16:
|
||||||
|
// *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
|
||||||
|
break; |
||||||
|
default: |
||||||
|
AT_ERROR("Unsupported dtype"); |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype, |
||||||
|
size_t offset) { |
||||||
|
switch (dtype) { |
||||||
|
case at::ScalarType::Float: |
||||||
|
return reinterpret_cast<float *>(ptr) + offset; |
||||||
|
case at::ScalarType::Half: |
||||||
|
return reinterpret_cast<at::Half *>(ptr) + offset; |
||||||
|
// case at::ScalarType::BFloat16:
|
||||||
|
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
|
||||||
|
default: |
||||||
|
AT_ERROR("Unsupported dtype"); |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
#define STEP(SPAN) \ |
||||||
|
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
|
||||||
|
void *_exp_avg_sq, size_t _param_size, \
|
||||||
|
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
|
||||||
|
at::ScalarType exp_avg_dtype, \
|
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1); |
||||||
|
|
||||||
|
class AdamOptimizer { |
||||||
|
private: |
||||||
|
float _alpha; |
||||||
|
float _betta1; |
||||||
|
float _betta2; |
||||||
|
float _eps; |
||||||
|
float _weight_decay; |
||||||
|
|
||||||
|
float _betta1_t; |
||||||
|
float _betta2_t; |
||||||
|
size_t _step; |
||||||
|
|
||||||
|
float _bias_correction1; |
||||||
|
float _bias_correction2; |
||||||
|
|
||||||
|
bool _adamw_mode; |
||||||
|
|
||||||
|
public: |
||||||
|
AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, |
||||||
|
float eps = 1e-8, float weight_decay = 0, |
||||||
|
bool adamw_mode = true) |
||||||
|
: _alpha(alpha), |
||||||
|
_betta1(betta1), |
||||||
|
_betta2(betta2), |
||||||
|
_eps(eps), |
||||||
|
_weight_decay(weight_decay), |
||||||
|
_betta1_t(1.0), |
||||||
|
_betta2_t(1.0), |
||||||
|
_step(0), |
||||||
|
_adamw_mode(adamw_mode) {} |
||||||
|
~AdamOptimizer() {} |
||||||
|
|
||||||
|
STEP(1) |
||||||
|
STEP(4) |
||||||
|
STEP(8) |
||||||
|
inline void IncrementStep(size_t step, float beta1, float beta2) { |
||||||
|
if (beta1 != _betta1 || beta2 != _betta2) { |
||||||
|
_step = step; |
||||||
|
_betta1 = beta1; |
||||||
|
_betta2 = beta2; |
||||||
|
_betta1_t = std::pow(_betta1, step); |
||||||
|
_betta2_t = std::pow(_betta2, step); |
||||||
|
} else { |
||||||
|
_step++; |
||||||
|
if (_step != step) { |
||||||
|
_betta1_t = std::pow(_betta1, step); |
||||||
|
_betta2_t = std::pow(_betta2, step); |
||||||
|
_step = step; |
||||||
|
} else { |
||||||
|
_betta1_t *= _betta1; |
||||||
|
_betta2_t *= _betta2; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
inline void update_state(float lr, float epsilon, float weight_decay, |
||||||
|
bool bias_correction) { |
||||||
|
_alpha = lr; |
||||||
|
_eps = epsilon; |
||||||
|
_weight_decay = weight_decay; |
||||||
|
|
||||||
|
_bias_correction1 = 1.0f; |
||||||
|
_bias_correction2 = 1.0f; |
||||||
|
if (bias_correction == 1) { |
||||||
|
_bias_correction1 = 1 - _betta1_t; |
||||||
|
_bias_correction2 = 1 / sqrt(1 - _betta2_t); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
void step(size_t step, float lr, float beta1, float beta2, float epsilon, |
||||||
|
float weight_decay, bool bias_correction, torch::Tensor ¶ms, |
||||||
|
torch::Tensor &grads, torch::Tensor &exp_avg, |
||||||
|
torch::Tensor &exp_avg_sq, float loss_scale); |
||||||
|
}; |
@ -1,56 +0,0 @@ |
|||||||
#!/usr/bin/env python |
|
||||||
# -*- encoding: utf-8 -*- |
|
||||||
|
|
||||||
from typing import Optional |
|
||||||
|
|
||||||
import torch |
|
||||||
import torch.distributed as dist |
|
||||||
|
|
||||||
|
|
||||||
def set_to_cuda(models): |
|
||||||
"""Send model to gpu. |
|
||||||
|
|
||||||
:param models: nn.module or a list of module |
|
||||||
""" |
|
||||||
if isinstance(models, list) and len(models) > 1: |
|
||||||
ret = [] |
|
||||||
for model in models: |
|
||||||
ret.append(model.to(get_current_device())) |
|
||||||
return ret |
|
||||||
elif isinstance(models, list): |
|
||||||
return models[0].to(get_current_device()) |
|
||||||
else: |
|
||||||
return models.to(get_current_device()) |
|
||||||
|
|
||||||
|
|
||||||
def get_current_device() -> torch.device: |
|
||||||
""" |
|
||||||
Returns currently selected device (gpu/cpu). |
|
||||||
If cuda available, return gpu, otherwise return cpu. |
|
||||||
""" |
|
||||||
if torch.cuda.is_available(): |
|
||||||
return torch.device(f"cuda:{torch.cuda.current_device()}") |
|
||||||
else: |
|
||||||
return torch.device("cpu") |
|
||||||
|
|
||||||
|
|
||||||
def synchronize(): |
|
||||||
"""Similar to cuda.synchronize(). |
|
||||||
Waits for all kernels in all streams on a CUDA device to complete. |
|
||||||
""" |
|
||||||
if torch.cuda.is_available(): |
|
||||||
torch.cuda.synchronize() |
|
||||||
|
|
||||||
|
|
||||||
def empty_cache(): |
|
||||||
"""Similar to cuda.empty_cache() |
|
||||||
Releases all unoccupied cached memory currently held by the caching allocator. |
|
||||||
""" |
|
||||||
if torch.cuda.is_available(): |
|
||||||
torch.cuda.empty_cache() |
|
||||||
|
|
||||||
|
|
||||||
def set_device(index: Optional[int] = None) -> None: |
|
||||||
if index is None: |
|
||||||
index = dist.get_rank() % torch.cuda.device_count() |
|
||||||
torch.cuda.set_device(index) |
|
@ -0,0 +1,207 @@ |
|||||||
|
#!/usr/bin/env python |
||||||
|
# -*- encoding: utf-8 -*- |
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.distributed as dist |
||||||
|
|
||||||
|
IS_NPU_AVAILABLE: bool = False |
||||||
|
try: |
||||||
|
import torch_npu # noqa |
||||||
|
|
||||||
|
IS_NPU_AVAILABLE = torch.npu.is_available() |
||||||
|
except ImportError: |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
def set_to_cuda(models): |
||||||
|
"""Send model to gpu. |
||||||
|
|
||||||
|
:param models: nn.module or a list of module |
||||||
|
""" |
||||||
|
if isinstance(models, list) and len(models) > 1: |
||||||
|
ret = [] |
||||||
|
for model in models: |
||||||
|
ret.append(model.to(get_current_device())) |
||||||
|
return ret |
||||||
|
elif isinstance(models, list): |
||||||
|
return models[0].to(get_current_device()) |
||||||
|
else: |
||||||
|
return models.to(get_current_device()) |
||||||
|
|
||||||
|
|
||||||
|
def get_current_device() -> torch.device: |
||||||
|
""" |
||||||
|
Returns currently selected device (gpu/cpu). |
||||||
|
If cuda available, return gpu, otherwise return cpu. |
||||||
|
""" |
||||||
|
if torch.cuda.is_available(): |
||||||
|
return torch.device(f"cuda:{torch.cuda.current_device()}") |
||||||
|
elif IS_NPU_AVAILABLE: |
||||||
|
return torch.device(f"npu:{torch.npu.current_device()}") |
||||||
|
else: |
||||||
|
return torch.device("cpu") |
||||||
|
|
||||||
|
|
||||||
|
def _dispatch_device_func(fn_name: str, *args, **kwargs): |
||||||
|
if torch.cuda.is_available(): |
||||||
|
return getattr(torch.cuda, fn_name)(*args, **kwargs) |
||||||
|
elif IS_NPU_AVAILABLE: |
||||||
|
return getattr(torch.npu, fn_name)(*args, **kwargs) |
||||||
|
else: |
||||||
|
raise RuntimeError("No device available") |
||||||
|
|
||||||
|
|
||||||
|
# device semantics |
||||||
|
|
||||||
|
|
||||||
|
def can_device_access_peer(device, peer_device) -> bool: |
||||||
|
return _dispatch_device_func("can_device_access_peer", device, peer_device) |
||||||
|
|
||||||
|
|
||||||
|
def current_device() -> int: |
||||||
|
return _dispatch_device_func("current_device") |
||||||
|
|
||||||
|
|
||||||
|
def current_stream(device=None): |
||||||
|
return _dispatch_device_func("current_stream", device) |
||||||
|
|
||||||
|
|
||||||
|
def default_stream(device=None): |
||||||
|
return _dispatch_device_func("default_stream", device) |
||||||
|
|
||||||
|
|
||||||
|
def device_count() -> int: |
||||||
|
return _dispatch_device_func("device_count") |
||||||
|
|
||||||
|
|
||||||
|
def get_device_capability(device=None) -> Tuple[int, int]: |
||||||
|
return _dispatch_device_func("get_device_capability", device) |
||||||
|
|
||||||
|
|
||||||
|
def get_device_name(device=None) -> str: |
||||||
|
return _dispatch_device_func("get_device_name", device) |
||||||
|
|
||||||
|
|
||||||
|
def get_device_properties(device): |
||||||
|
return _dispatch_device_func("get_device_properties", device) |
||||||
|
|
||||||
|
|
||||||
|
def set_device(index: Optional[int] = None) -> None: |
||||||
|
if index is None: |
||||||
|
index = dist.get_rank() % device_count() |
||||||
|
_dispatch_device_func("set_device", index) |
||||||
|
|
||||||
|
|
||||||
|
def set_stream(stream_): |
||||||
|
return _dispatch_device_func("set_stream", stream_) |
||||||
|
|
||||||
|
|
||||||
|
def stream(stream_): |
||||||
|
return _dispatch_device_func("stream", stream_) |
||||||
|
|
||||||
|
|
||||||
|
def synchronize(): |
||||||
|
return _dispatch_device_func("synchronize") |
||||||
|
|
||||||
|
|
||||||
|
def utilization(device=None) -> int: |
||||||
|
return _dispatch_device_func("utilization", device) |
||||||
|
|
||||||
|
|
||||||
|
# random number generator |
||||||
|
|
||||||
|
|
||||||
|
def get_rng_state(device="cuda") -> torch.Tensor: |
||||||
|
return _dispatch_device_func("get_rng_state", device) |
||||||
|
|
||||||
|
|
||||||
|
def get_rng_state_all() -> List[torch.Tensor]: |
||||||
|
return _dispatch_device_func("get_rng_state_all") |
||||||
|
|
||||||
|
|
||||||
|
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: |
||||||
|
return _dispatch_device_func("set_rng_state", new_state, device) |
||||||
|
|
||||||
|
|
||||||
|
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: |
||||||
|
return _dispatch_device_func("set_rng_state_all", new_states) |
||||||
|
|
||||||
|
|
||||||
|
def manual_seed(seed: int) -> None: |
||||||
|
return _dispatch_device_func("manual_seed", seed) |
||||||
|
|
||||||
|
|
||||||
|
def manual_seed_all(seed: int) -> None: |
||||||
|
return _dispatch_device_func("manual_seed_all", seed) |
||||||
|
|
||||||
|
|
||||||
|
def seed() -> None: |
||||||
|
return _dispatch_device_func("seed") |
||||||
|
|
||||||
|
|
||||||
|
def seed_all() -> None: |
||||||
|
return _dispatch_device_func("seed_all") |
||||||
|
|
||||||
|
|
||||||
|
def initial_seed() -> int: |
||||||
|
return _dispatch_device_func("initial_seed") |
||||||
|
|
||||||
|
|
||||||
|
# streams and events |
||||||
|
|
||||||
|
|
||||||
|
def Stream(device=None, priority=0, **kwargs): |
||||||
|
return _dispatch_device_func("Stream", device, priority, **kwargs) |
||||||
|
|
||||||
|
|
||||||
|
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): |
||||||
|
return _dispatch_device_func("Event", enable_timing, blocking, interprocess) |
||||||
|
|
||||||
|
|
||||||
|
# memory management |
||||||
|
|
||||||
|
|
||||||
|
def empty_cache() -> None: |
||||||
|
return _dispatch_device_func("empty_cache") |
||||||
|
|
||||||
|
|
||||||
|
def memory_stats(device=None) -> Dict[str, Any]: |
||||||
|
return _dispatch_device_func("memory_stats", device) |
||||||
|
|
||||||
|
|
||||||
|
def memory_summary(device=None, abbreviated=False) -> str: |
||||||
|
return _dispatch_device_func("memory_summary", device, abbreviated) |
||||||
|
|
||||||
|
|
||||||
|
def memory_snapshot(): |
||||||
|
return _dispatch_device_func("memory_snapshot") |
||||||
|
|
||||||
|
|
||||||
|
def memory_allocated(device=None) -> int: |
||||||
|
return _dispatch_device_func("memory_allocated", device) |
||||||
|
|
||||||
|
|
||||||
|
def max_memory_allocated(device=None) -> int: |
||||||
|
return _dispatch_device_func("max_memory_allocated", device) |
||||||
|
|
||||||
|
|
||||||
|
def reset_max_memory_allocated(device=None) -> None: |
||||||
|
return _dispatch_device_func("reset_max_memory_allocated", device) |
||||||
|
|
||||||
|
|
||||||
|
def memory_reserved(device=None) -> int: |
||||||
|
return _dispatch_device_func("memory_reserved", device) |
||||||
|
|
||||||
|
|
||||||
|
def max_memory_reserved(device=None) -> int: |
||||||
|
return _dispatch_device_func("max_memory_reserved", device) |
||||||
|
|
||||||
|
|
||||||
|
def set_per_process_memory_fraction(fraction: float, device=None) -> None: |
||||||
|
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) |
||||||
|
|
||||||
|
|
||||||
|
def reset_peak_memory_stats(device=None) -> None: |
||||||
|
return _dispatch_device_func("reset_peak_memory_stats", device) |
@ -0,0 +1,34 @@ |
|||||||
|
from .builder import Builder |
||||||
|
|
||||||
|
|
||||||
|
class ArmCPUAdamBuilder(Builder): |
||||||
|
NAME = "arm_cpu_adam" |
||||||
|
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" |
||||||
|
ext_type = "cpu" |
||||||
|
|
||||||
|
def __init__(self): |
||||||
|
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) |
||||||
|
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] |
||||||
|
|
||||||
|
# necessary 4 functions |
||||||
|
def sources_files(self): |
||||||
|
ret = [ |
||||||
|
self.csrc_abs_path("cpu_adam_arm.cpp"), |
||||||
|
] |
||||||
|
return ret |
||||||
|
|
||||||
|
def include_dirs(self): |
||||||
|
return [self.csrc_abs_path("includes")] |
||||||
|
|
||||||
|
def cxx_flags(self): |
||||||
|
extra_cxx_flags = [ |
||||||
|
"-std=c++14", |
||||||
|
"-std=c++17", |
||||||
|
"-g", |
||||||
|
"-Wno-reorder", |
||||||
|
"-fopenmp", |
||||||
|
] |
||||||
|
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags |
||||||
|
|
||||||
|
def nvcc_flags(self): |
||||||
|
return [] |
Loading…
Reference in new issue