mirror of https://github.com/hpcaitech/ColossalAI
[optimizer] add div_scale for optimizers (#2117)
* [optimizer] add div_scale for optimizers
* [zero] use div_scale in zero optimizer
* fix testing error

Branch: pull/2123/head
parent e5aa8333e4
commit e7d3afc9cc
@@ -11,7 +11,7 @@ def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List
     ...


-def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None:
+def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
     ...

@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay);
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale);

 void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
@@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Computes and apply update for LAMB optimizer");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
-}
+}

@@ -28,7 +28,7 @@ struct AdamFunctor {
       int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
       const float beta1, const float beta2, const float beta1_correction,
       const float beta2_correction, const float epsilon, const float lr,
-      adamMode_t mode, const float decay) {
+      adamMode_t mode, const float decay, const float div_scale) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -79,6 +79,8 @@ struct AdamFunctor {
         }
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
+          if (div_scale > 0) r_g[ii] /= div_scale;
+
           if (mode == ADAM_MODE_0) {  // L2
             r_g[ii] = r_g[ii] + (decay * r_p[ii]);
             r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay) {
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale) {
   using namespace at;

   // Handle bias correction mode
@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
   multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                         AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
                         beta2, bias_correction1, bias_correction2, epsilon,
-                        lr, (adamMode_t)mode, weight_decay);)
+                        lr, (adamMode_t)mode, weight_decay, div_scale);)

   AT_CUDA_CHECK(cudaGetLastError());
-}
+}

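The kernel change above is the heart of the commit: when div_scale is positive, each gradient element is divided by it before the Adam math runs, so gradient unscaling no longer needs a separate pass over the tensors; the default of -1 leaves gradients untouched and preserves the old behaviour. For reference, a minimal Python sketch of the per-element update the functor now performs. This is illustrative only: the function name ref_adam_update and the inline bias correction are mine (the real extension receives precomputed correction terms and works on ILP-vectorized registers).

import math

def ref_adam_update(p, g, m, v, lr, beta1, beta2, eps, step, decay,
                    adamw_mode=True, div_scale=-1.0):
    # Per-element sketch of AdamFunctor after this commit.
    if div_scale > 0:
        g = g / div_scale                  # the newly added pre-division
    if not adamw_mode:                     # ADAM_MODE_0: L2 regularization
        g = g + decay * p
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    bc1 = 1 - beta1 ** step                # bias correction (done on the
    bc2 = 1 - beta2 ** step                # host side in the real extension)
    update = (m / bc1) / (math.sqrt(v / bc2) + eps)
    if adamw_mode:                         # ADAM_MODE_1: decoupled weight decay
        update += decay * p
    return p - lr * update, m, v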
@@ -117,7 +117,7 @@ class CPUAdam(NVMeOptimizer):
                 data.addcdiv_(exp_avg, denom, value=-step_size)

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -152,9 +152,10 @@ class CPUAdam(NVMeOptimizer):
                 self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                 self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                       group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                      state['exp_avg_sq'], -1)
+                                      state['exp_avg_sq'], div_scale)
                 self._post_update(p, 'exp_avg', 'exp_avg_sq')
             elif target_device.type == 'cuda':
+                assert div_scale == -1, "div_scale should remain default"
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"

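With this change, CPUAdam.step can absorb the loss scale directly on the CPU path, while the new assert keeps the default -1 mandatory for CUDA-resident parameters. A minimal usage sketch, assuming the CPU Adam extension is built; the tensor shapes and the 1024 loss scale are made-up values:

import torch
from colossalai.nn.optimizer import CPUAdam

loss_scale = 1024.0
p = torch.nn.Parameter(torch.randn(4, 4))       # parameter kept on CPU
opt = CPUAdam([p], lr=1e-3)

# Pretend backward ran on a loss multiplied by loss_scale.
p.grad = torch.randn(4, 4) * loss_scale

# The extension divides the gradient by div_scale inside its step,
# so no separate unscaling loop over the gradients is needed.
opt.step(div_scale=loss_scale)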
@@ -81,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
         else:
             super(FusedAdam, self).zero_grad()

-    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
         """Performs a single optimization step.

         Arguments:
@@ -137,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):

             multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                  beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
-                                 group['weight_decay'])
+                                 group['weight_decay'], div_scale)

         return loss
@@ -89,7 +89,7 @@ class HybridAdam(NVMeOptimizer):
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -126,7 +126,7 @@ class HybridAdam(NVMeOptimizer):
                 self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                 self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                       group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                      state['exp_avg_sq'], -1)
+                                      state['exp_avg_sq'], div_scale)
                 self._post_update(p, 'exp_avg', 'exp_avg_sq')

             elif target_device.type == 'cuda':
@@ -146,6 +146,6 @@ class HybridAdam(NVMeOptimizer):
                 bias_correction = 1 if group['bias_correction'] else 0
                 multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                      group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
-                                     bias_correction, group['weight_decay'])
+                                     bias_correction, group['weight_decay'], div_scale)
         self._post_step()
         return loss

@@ -10,10 +10,12 @@ from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.nn.parallel.data_parallel import ZeroDDP
 from colossalai.utils import disposable, get_current_device

+_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
+

 class OptimState(Enum):
     SCALED = 0
@@ -62,6 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
+        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
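The new _AVAIL_OPTIM_LIST assert makes the contract explicit: ZeroOptimizer forwards div_scale to the inner optimizer's step, so only the three optimizers that accept it are allowed. A sketch of what that looks like at construction time; the module is assumed to already be wrapped in ZeroDDP, and the public import path for ZeroOptimizer may differ between versions:

import torch
from colossalai.nn.optimizer import HybridAdam
# from colossalai.zero import ZeroOptimizer  # import path may vary by version

# Accepted: HybridAdam (likewise FusedAdam or CPUAdam).
optimizer = ZeroOptimizer(HybridAdam(module.parameters(), lr=1e-3), module)

# Rejected: a stock torch optimizer has no div_scale parameter, so the
# constructor fails fast instead of erroring inside step().
optimizer = ZeroOptimizer(torch.optim.Adam(module.parameters(), lr=1e-3), module)
# AssertionError: you should use the optimizer in the available list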
@@ -162,21 +165,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
         global_norm = math.sqrt(norm_sqr)
         return global_norm

-    def _unscale_and_clip_grads(self):
-        assert self.optim_state == OptimState.SCALED
+    def _get_combined_scale(self):
+        loss_scale = 1

-        combined_scale = self.loss_scale
+        if self.optim_state == OptimState.SCALED:
+            loss_scale = self.loss_scale
+            self.optim_state = OptimState.UNSCALED
+
+        combined_scale = loss_scale
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * self.loss_scale
+                combined_scale = clip * loss_scale

-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    p.grad.data.div_(combined_scale)
-        self.optim_state = OptimState.UNSCALED
+        if combined_scale == 1:
+            return -1
+        else:
+            return combined_scale

     @property
     def loss_scale(self):
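The refactor above replaces "unscale and clip the gradients in an extra pass" with "compute one scalar and let the optimizer kernels divide". A standalone sketch of the same arithmetic, with made-up numbers to show the effect (the function and argument names are illustrative, not repo code):

def combined_scale(loss_scale, total_norm, max_norm, clipping=True):
    # combined scale = loss scale * clipping factor, so dividing a gradient
    # by it both unscales it and clips the global norm to max_norm.
    scale = loss_scale
    if clipping:
        clip = (total_norm / loss_scale + 1e-6) / max_norm
        if clip > 1:
            scale = clip * loss_scale
    # -1 is the sentinel the Adam kernels treat as "do not divide".
    return -1 if scale == 1 else scale

# Scaled-gradient norm of 4096 with loss scale 1024 -> true norm 4, so the
# gradients must additionally be clipped by ~4x: combined scale ~= 4096.
combined_scale(loss_scale=1024.0, total_norm=4096.0, max_norm=1.0)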
@@ -199,12 +205,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
             self._update_fp16_params()
             return

-        # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_and_clip_grads()
+        # get combined scale. combined scale = loss scale * clipping norm
+        # so that gradient = gradient / combined scale
+        combined_scale = self._get_combined_scale()
         self.grad_scaler.update(found_inf)

-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
         self.zero_grad()
         self._update_fp16_params()

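Putting it together, step no longer mutates the gradients before calling the wrapped optimizer; the division moves into the fused/CPU kernels via div_scale. A simplified before/after of the flow, written as two illustrative helpers rather than repo functions (overflow handling and state re-registration omitted):

def old_step(optim, combined_scale):
    # Pre-commit flow: an extra in-place pass over every gradient.
    for group in optim.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.data.div_(combined_scale)
    return optim.step()

def new_step(optim, combined_scale):
    # Post-commit flow: one scalar, divided inside the optimizer's kernel.
    return optim.step(div_scale=combined_scale)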
@@ -71,7 +71,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
     weight_decay = 0

     multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
-                         True, weight_decay)
+                         True, weight_decay, -1)

     torch_adam_update(
         step,
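The test only pins the new trailing argument to -1, i.e. the old behaviour. A natural follow-up check, not part of this commit and sketched here under the same fixtures as test_adam above (fused_adam, dummy_overflow_buf, p/g/m/v, adamw and the hyper-parameters are assumed to exist exactly as in that test; the div_scale of 128 is arbitrary), is that a positive div_scale matches feeding the kernel already-unscaled gradients:

import torch

div_scale = 128.0
p1, m1, v1 = p.clone(), m.clone(), v.clone()
p2, m2, v2 = p.clone(), m.clone(), v.clone()

# Step once on gradients scaled up by div_scale, asking the kernel to divide...
multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g * div_scale], [p1], [m1], [v1]],
                     lr, beta1, beta2, eps, step, adamw, True, weight_decay, div_scale)
# ...and once on the original gradients with the "no division" default.
multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g.clone()], [p2], [m2], [v2]],
                     lr, beta1, beta2, eps, step, adamw, True, weight_decay, -1)

assert torch.allclose(p1, p2, rtol=1e-5, atol=1e-6)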