mirror of https://github.com/hpcaitech/ColossalAI

[hotfix] fix CPUAdam kernel nullptr (#1410)
parent 1e5eb0874c
commit 12b4887097
@@ -24,15 +24,12 @@ SOFTWARE
#include <math.h>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>

#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,

@@ -310,35 +307,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
              grad_half_precision, loss_scale);
}

int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
                          float betta1 = 0.9, float betta2 = 0.999,
                          float eps = 1e-8, float weight_decay = 0,
                          bool adamw_mode = true, bool should_log = false) {
  auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps,
                                              weight_decay, adamw_mode);

  s_optimizers[optimizer_id] = opt;

  if (should_log) {
    std::string avx_type = "";
#if defined(__AVX512__)
    avx_type = "AVX512";
#else
#if defined(__AVX256__) or defined(__AVX2__)
    avx_type = "AVX2";
#else
    avx_type = "scalar";
#endif
#endif
    printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
           optimizer_id, avx_type.c_str());
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha, betta1, betta2, weight_decay, (int)adamw_mode);
  }

  return 0;
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,

@@ -460,11 +428,11 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
              grad_half_precision, loss_scale);
}

int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
              float epsilon, float weight_decay, bool bias_correction,
              torch::Tensor &params, torch::Tensor &grads,
              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
              float loss_scale) {
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
                          float epsilon, float weight_decay,
                          bool bias_correction, torch::Tensor &params,
                          torch::Tensor &grads, torch::Tensor &exp_avg,
                          torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();

@@ -474,24 +442,18 @@ int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
  float *grads_ptr = (float *)grads_c.data_ptr();
  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
  std::shared_ptr<Adam_Optimizer> opt =
      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
  opt->IncrementStep(step, beta1, beta2);
  opt->update_state(lr, epsilon, weight_decay, bias_correction);
  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
              params_c.numel(), (params.options().dtype() == at::kHalf),
              (grads.options().dtype() == at::kHalf), loss_scale);

  return 0;
  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
               params_c.numel(), (params.options().dtype() == at::kHalf),
               (grads.options().dtype() == at::kHalf), loss_scale);
}

int destroy_adam_optimizer(int optimizer_id) {
  s_optimizers.erase(optimizer_id);
  return 0;
}
namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam_update", &adam_step, "CPU Adam update (C++)");
  m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)");
  m.def("destroy_adam", &destroy_adam_optimizer, "CPU Adam destroy (C++)");
  py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &Adam_Optimizer::step);
}
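Note on the C++ hunks above: the patch drops the global `s_optimizers` registry and the id-based free functions (`create_adam_optimizer`, `adam_step`, `destroy_adam_optimizer`) and instead exposes `Adam_Optimizer::step` directly through pybind11 as the `CPUAdamOptimizer` class, so each Python wrapper owns its own C++ optimizer instead of looking one up by id. A minimal sketch of driving the new binding from Python, assuming the `cpu_adam` extension is built; tensor sizes and hyperparameters are illustrative, not part of the patch:

    import torch
    import cpu_adam  # the extension built from the C++ sources above

    lr, beta1, beta2, eps, weight_decay, adamw_mode = 1e-3, 0.9, 0.999, 1e-8, 0.0, True
    opt = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw_mode)

    param = torch.rand(64)        # fp32 parameter tensor on CPU
    grad = torch.rand(64)         # gradient of the same shape
    exp_avg = torch.zeros(64)     # Adam first moment
    exp_avg_sq = torch.zeros(64)  # Adam second moment

    # step(step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
    #      params, grads, exp_avg, exp_avg_sq, loss_scale)
    # -1 for loss_scale is what the Python optimizers in this patch pass.
    opt.step(1, lr, beta1, beta2, eps, weight_decay, True,
             param, grad, exp_avg, exp_avg_sq, -1)
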
@@ -26,7 +26,7 @@ SOFTWARE
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>

@@ -141,6 +141,11 @@ class Adam_Optimizer {
    }
  }

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);

 private:
  float _alpha;
  float _betta1;
@@ -1,4 +1,3 @@
from .utils import CPU_ADAM_CNT
from .colossalai_optimizer import ColossalaiOptimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB

@@ -8,6 +7,4 @@ from .lars import Lars
from .cpu_adam import CPUAdam
from .hybrid_adam import HybridAdam

__all__ = [
    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
]
__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
@@ -2,7 +2,6 @@ import math
import torch

from colossalai.registry import OPTIMIZERS
from colossalai.nn.optimizer import CPU_ADAM_CNT
from .nvme_optimizer import NVMeOptimizer
from typing import Optional

@@ -69,25 +68,17 @@ class CPUAdam(NVMeOptimizer):
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 simd_log=False,
                 nvme_offload_fraction: float = 0.0,
                 nvme_offload_dir: Optional[str] = None):

        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
        self.opt_id = CPU_ADAM_CNT()
        self.adamw_mode = adamw_mode
        try:
            import cpu_adam
        except ImportError:
            raise ImportError('Please install colossalai from source code to use CPUAdam')
        self.cpu_adam_op = cpu_adam
        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

    def __del__(self):
        super().__del__()
        if getattr(self, 'cpu_adam_op', None):
            self.cpu_adam_op.destroy_adam(self.opt_id)
        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

    def torch_adam_update(self,
                          data,

@@ -156,9 +147,9 @@ class CPUAdam(NVMeOptimizer):
                assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'], -1)
                self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                      group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
                                      state['exp_avg_sq'], -1)
                self._post_update(p, 'exp_avg', 'exp_avg_sq')
            elif target_device.type == 'cuda':
                assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
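With the change above, `CPUAdam` builds its kernel object directly in `__init__` (`cpu_adam.CPUAdamOptimizer(...)`) and no longer needs an optimizer id or a `__del__` hook calling `destroy_adam`. A usage sketch, assuming a small CPU-resident model; the model and training loop are illustrative, not part of the patch:

    import torch
    import torch.nn as nn
    from colossalai.nn.optimizer import CPUAdam

    model = nn.Linear(64, 64)  # parameters kept on CPU
    optim = CPUAdam(model.parameters(), lr=1e-3, adamw_mode=True)

    for _ in range(10):
        x = torch.rand(4, 64)
        loss = model(x).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()  # CPU params go through cpu_adam.CPUAdamOptimizer.step
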
@@ -2,7 +2,6 @@ import torch

from colossalai.utils import multi_tensor_applier
from colossalai.registry import OPTIMIZERS
from colossalai.nn.optimizer import CPU_ADAM_CNT
from typing import Optional
from .nvme_optimizer import NVMeOptimizer

@@ -68,13 +67,11 @@ class HybridAdam(NVMeOptimizer):
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 simd_log=False,
                 nvme_offload_fraction: float = 0.0,
                 nvme_offload_dir: Optional[str] = None):

        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
        self.opt_id = CPU_ADAM_CNT()
        self.adamw_mode = adamw_mode
        try:
            import cpu_adam

@@ -82,17 +79,11 @@ class HybridAdam(NVMeOptimizer):
        except ImportError:
            raise ImportError('Please install colossalai from source code to use HybridAdam')

        self.cpu_adam_op = cpu_adam
        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

        self.gpu_adam_op = colossal_C.multi_tensor_adam
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

    def __del__(self):
        super().__del__()
        if getattr(self, 'cpu_adam_op', None):
            self.cpu_adam_op.destroy_adam(self.opt_id)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None

@@ -129,9 +120,9 @@ class HybridAdam(NVMeOptimizer):
                assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'], -1)
                self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                      group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
                                      state['exp_avg_sq'], -1)
                self._post_update(p, 'exp_avg', 'exp_avg_sq')

            elif target_device.type == 'cuda':
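As in `CPUAdam`, `HybridAdam` now owns a `CPUAdamOptimizer` instance for CPU parameters while keeping the fused CUDA path (`colossal_C.multi_tensor_adam` driven by `multi_tensor_applier`) for GPU parameters. A condensed sketch of that per-parameter dispatch, simplified from the hunks above; the helper name `hybrid_adam_update` is hypothetical, and state initialization, NVMe offload and the batching of CUDA tensors are omitted:

    from colossalai.utils import multi_tensor_applier

    def hybrid_adam_update(cpu_adam_op, gpu_adam_op, overflow_buf, adamw_mode,
                           p, state, group, beta1, beta2):
        # Hypothetical helper condensing HybridAdam's device dispatch.
        if p.device.type == 'cpu':
            cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
                             group['weight_decay'], group['bias_correction'],
                             p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], -1)
        else:  # 'cuda': go through the fused multi-tensor kernel
            multi_tensor_applier(gpu_adam_op, overflow_buf,
                                 [[p.grad.data], [p.data], [state['exp_avg']], [state['exp_avg_sq']]],
                                 group['lr'], beta1, beta2, group['eps'], state['step'],
                                 adamw_mode, group['bias_correction'], group['weight_decay'])
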
@@ -1,14 +0,0 @@
class CpuAdamCounter(object):
    """Used to record the total number of CPU Adam.
    We must use it to avoid hybrid cpu adam and cpu adam using the same id.
    """

    def __init__(self):
        self.number = 0

    def __call__(self):
        self.number += 1
        return self.number - 1


CPU_ADAM_CNT = CpuAdamCounter()
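The deleted `CpuAdamCounter` existed only to hand out unique ids for the old `create_adam`/`adam_update`/`destroy_adam` interface, so that `CPUAdam` and `HybridAdam` instances would not collide in the C++ registry. With the class binding, every optimizer owns its own kernel object and no id coordination is needed. A sketch of the contrast, assuming the built `cpu_adam` extension; hyperparameter values are illustrative:

    import cpu_adam

    # Old pattern (removed): a process-wide counter handed out ids keyed into s_optimizers.
    #   opt_id = CPU_ADAM_CNT()
    #   cpu_adam.create_adam(opt_id, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False)
    #   cpu_adam.adam_update(opt_id, step, lr, beta1, beta2, eps, weight_decay, True,
    #                        params, grads, exp_avg, exp_avg_sq, -1)

    # New pattern: independent objects, so a CPUAdam and a HybridAdam can coexist freely.
    opt_a = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
    opt_b = cpu_adam.CPUAdamOptimizer(2e-3, 0.9, 0.999, 1e-8, 0.0, True)
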
@@ -54,7 +54,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    weight_decay = 0

    for i in range(1024):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()

@@ -67,13 +67,11 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
    try:
        import cpu_adam
        cpu_adam_op = cpu_adam
        cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
    except:
        raise ImportError("Import cpu adam error, please install colossal from source code")

    cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
    cpu_adam_op.adam_update(
        0,
    cpu_adam_op.step(
        step,
        lr,
        beta1,
@@ -8,9 +8,11 @@ from colossalai.testing import parameterize


class FC(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(64, 64))

    def forward(self, x):
        return self.fc(x)


@@ -37,7 +39,7 @@ def test_adam(adamw, p_dtype, g_dtype):

    for d, l in zip(data, label):
        y = model(d)
        loss = ((l - y) ** 2).sum()
        loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:

@@ -47,13 +49,13 @@ def test_adam(adamw, p_dtype, g_dtype):

    for d, l in zip(data_copy, label):
        y = model_copy(d)
        loss = ((l - y) ** 2).sum()
        loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

    for i in range(len(optim.param_groups[0]['params'])):
        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
@@ -7,6 +7,7 @@ import math
from colossalai.testing import parameterize
from colossalai.utils import multi_tensor_applier


def torch_adam_update(
    step,
    lr,

@@ -51,7 +52,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
        dummy_overflow_buf = torch.cuda.IntTensor([0])
    except:
        raise ImportError("No colossal_C kernel installed.")

    count = 0

    for i in range(1024):

@@ -69,26 +70,26 @@ def test_adam(adamw, step, p_dtype, g_dtype):
        eps = 1e-8
        weight_decay = 0

        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
                             lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay)
        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            adamw,
        )

            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            count += 1
            continue
        assert count < 200, "too many nans"
        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
                              1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
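For reference, an annotated form of the `multi_tensor_applier` call that the hunk above reflows onto one line. The argument meanings are inferred from the test's variable names and the fused-Adam calling convention, so treat this as an illustrative sketch; it also assumes a CUDA build with the `colossal_C` extension:

    import torch
    import colossal_C
    from colossalai.utils import multi_tensor_applier

    p = torch.rand(64, device='cuda')         # parameters
    g = torch.rand(64, device='cuda')         # gradients
    m = torch.zeros(64, device='cuda')        # exp_avg
    v = torch.zeros(64, device='cuda')        # exp_avg_sq
    overflow_buf = torch.cuda.IntTensor([0])  # overflow/no-op flag buffer

    multi_tensor_applier(
        colossal_C.multi_tensor_adam, overflow_buf,
        [[g], [p], [m], [v]],    # grads, params, exp_avg, exp_avg_sq
        1e-3, 0.9, 0.999, 1e-8,  # lr, beta1, beta2, eps
        1,                       # step
        True,                    # adamw mode
        True,                    # bias correction
        0.0)                     # weight decay
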
@@ -17,7 +17,7 @@ def test_adam(adamw, device, p_dtype, g_dtype):
    rng_state = torch.get_rng_state()
    p = nn.Parameter(torch.rand(64).to(device, p_dtype))
    torch.set_rng_state(rng_state)
    p_copy = nn.Parameter(torch.rand(64).to(device).float())
    p_copy = nn.Parameter(torch.rand(64).to(device).float())

    if adamw:
        optim = HybridAdam([p], lr=1e-3, adamw_mode=True)

@@ -38,4 +38,4 @@ def test_adam(adamw, device, p_dtype, g_dtype):
        if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
            continue
        assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
            f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"
            f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"