[optimizer] add div_scale for optimizers (#2117)

* [optimizer] add div_scale for optimizers

* [zero] use div_scale in zero optimizer

* fix testing error
HELSON 2022-12-12 17:58:57 +08:00 committed by GitHub
parent e5aa8333e4
commit e7d3afc9cc
8 changed files with 41 additions and 32 deletions
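
Reviewer note on the new argument: div_scale folds gradient unscaling into the fused optimizer update itself. A value greater than zero means "divide each gradient element by div_scale before the Adam math runs"; the default of -1 (any non-positive value) leaves gradients untouched, so existing call sites keep their old behaviour. A minimal pure-Python sketch of that convention (illustrative only, not part of the patch):

import torch

def apply_div_scale(grad: torch.Tensor, div_scale: float = -1.0) -> torch.Tensor:
    # Mirrors the kernel-side rule introduced below: divide only when div_scale > 0.
    return grad / div_scale if div_scale > 0 else grad

g = torch.tensor([512.0, 1024.0])
assert torch.equal(apply_div_scale(g, 512.0), torch.tensor([1.0, 2.0]))  # unscaled inside the optimizer
assert torch.equal(apply_div_scale(g), g)                                # default -1 is a no-op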

Changed file 1 of 8: Python type stub for the fused multi-tensor optimizer kernels

@@ -11,7 +11,7 @@ def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List
     ...


-def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None:
+def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
     ...

Changed file 2 of 8: C++ frontend declaring multi_tensor_adam_cuda and its pybind11 bindings

@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay);
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale);

 void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
@@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Computes and apply update for LAMB optimizer");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
-}
+}

Changed file 3 of 8: fused multi-tensor Adam CUDA kernel (AdamFunctor / multi_tensor_adam_cuda)

@@ -28,7 +28,7 @@ struct AdamFunctor {
       int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
       const float beta1, const float beta2, const float beta1_correction,
       const float beta2_correction, const float epsilon, const float lr,
-      adamMode_t mode, const float decay) {
+      adamMode_t mode, const float decay, const float div_scale) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -79,6 +79,8 @@ struct AdamFunctor {
       }
 #pragma unroll
       for (int ii = 0; ii < ILP; ii++) {
+        if (div_scale > 0) r_g[ii] /= div_scale;
+
         if (mode == ADAM_MODE_0) { // L2
           r_g[ii] = r_g[ii] + (decay * r_p[ii]);
           r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay) {
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale) {
   using namespace at;

   // Handle bias correction mode
@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
                             beta2, bias_correction1, bias_correction2, epsilon,
-                            lr, (adamMode_t)mode, weight_decay);)
+                            lr, (adamMode_t)mode, weight_decay, div_scale);)

   AT_CUDA_CHECK(cudaGetLastError());
-}
+}
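
Reviewer note: for readers who prefer not to trace the CUDA, the sketch below restates the per-element update AdamFunctor performs, with the div_scale division being the only behavioural change in these hunks. The bias-correction and AdamW branches follow the standard apex-style fused Adam formulation and are an assumption here, since they are not fully visible in the hunks.

import torch

def adam_element_update(p, g, m, v, *, lr, beta1, beta2, eps, step,
                        weight_decay, adamw_mode=True, bias_correction=True,
                        div_scale=-1.0):
    # p, g, m, v are tensors (param, grad, exp_avg, exp_avg_sq).
    if div_scale > 0:                        # new in this commit: unscale inside the update
        g = g / div_scale
    bc1 = 1 - beta1**step if bias_correction else 1.0
    bc2 = 1 - beta2**step if bias_correction else 1.0
    if not adamw_mode:                       # ADAM_MODE_0: L2 penalty added to the gradient
        g = g + weight_decay * p
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    update = (m / bc1) / ((v / bc2).sqrt() + eps)
    if adamw_mode:                           # decoupled weight decay applied to the update
        update = update + weight_decay * p
    return p - lr * update, m, v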

Changed file 4 of 8: CPUAdam optimizer

@@ -117,7 +117,7 @@ class CPUAdam(NVMeOptimizer):
         data.addcdiv_(exp_avg, denom, value=-step_size)

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -152,9 +152,10 @@ class CPUAdam(NVMeOptimizer):
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                           group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], -1)
+                                          state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')
                 elif target_device.type == 'cuda':
+                    assert div_scale == -1, "div_scale should remain default"
                     assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                     assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
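
Reviewer note: with parameters resident on CPU, the combined scale can now be handed straight to the C++ op, so no separate pass over the gradients is needed; the new assert documents that the CUDA branch of CPUAdam does not support a custom div_scale. A hedged usage sketch (assumes colossalai with the cpu_adam extension built; not part of the patch):

import torch
from colossalai.nn.optimizer import CPUAdam   # same import path used by the ZeroOptimizer change below

model = torch.nn.Linear(8, 8)                 # parameters stay on CPU
optim = CPUAdam(model.parameters(), lr=1e-3)

loss_scale = 1024.0
(model(torch.randn(4, 8)).sum() * loss_scale).backward()   # scaled loss -> scaled gradients
optim.step(div_scale=loss_scale)              # gradients are divided by 1024 inside cpu_adam_op.step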

Changed file 5 of 8: FusedAdam optimizer

@@ -81,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
         else:
             super(FusedAdam, self).zero_grad()

-    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
         """Performs a single optimization step.

         Arguments:
@@ -137,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):

             multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                  beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
-                                 group['weight_decay'])
+                                 group['weight_decay'], div_scale)

         return loss
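
Reviewer note: because div_scale defaults to -1, FusedAdam.step() called without the argument behaves exactly as before; passing a positive value lets an FP16 loop skip the explicit gradient-unscaling pass. A hedged usage sketch (assumes a CUDA device and the fused_optim extension; not part of the patch):

import torch
from colossalai.nn.optimizer import FusedAdam

model = torch.nn.Linear(8, 8).cuda()
optim = FusedAdam(model.parameters(), lr=1e-3)

loss_scale = 2.0 ** 10
(model(torch.randn(4, 8, device='cuda')).sum() * loss_scale).backward()
optim.step(div_scale=loss_scale)   # the fused kernel unscales each gradient while it updates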

Changed file 6 of 8: HybridAdam optimizer

@@ -89,7 +89,7 @@ class HybridAdam(NVMeOptimizer):
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -126,7 +126,7 @@ class HybridAdam(NVMeOptimizer):
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                           group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], -1)
+                                          state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')

                 elif target_device.type == 'cuda':
@@ -146,6 +146,6 @@ class HybridAdam(NVMeOptimizer):
            bias_correction = 1 if group['bias_correction'] else 0
            multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                 group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
-                                bias_correction, group['weight_decay'])
+                                bias_correction, group['weight_decay'], div_scale)
         self._post_step()
         return loss
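
Reviewer note: HybridAdam forwards one and the same div_scale to both back ends, cpu_adam_op for CPU-resident optimizer states and the fused CUDA kernel for GPU-resident ones, so a parameter set split across devices is unscaled consistently. A hedged usage sketch (assumes colossalai with both extensions built; not part of the patch):

import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(8, 8).cuda()
optim = HybridAdam(model.parameters(), lr=1e-3)

model(torch.randn(4, 8, device='cuda')).sum().backward()
optim.step()                  # unchanged call sites keep the old behaviour (div_scale defaults to -1)
optim.zero_grad()

(model(torch.randn(4, 8, device='cuda')).sum() * 128.0).backward()
optim.step(div_scale=128.0)   # both backends receive the same scale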

Changed file 7 of 8: ZeroOptimizer

@@ -10,10 +10,12 @@ from torch.optim import Optimizer

 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.nn.parallel.data_parallel import ZeroDDP
 from colossalai.utils import disposable, get_current_device

+_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
+

 class OptimState(Enum):
     SCALED = 0
@@ -62,6 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
+        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
@@ -162,21 +165,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
         global_norm = math.sqrt(norm_sqr)
         return global_norm

-    def _unscale_and_clip_grads(self):
-        assert self.optim_state == OptimState.SCALED
+    def _get_combined_scale(self):
+        loss_scale = 1

-        combined_scale = self.loss_scale
+        if self.optim_state == OptimState.SCALED:
+            loss_scale = self.loss_scale
+            self.optim_state = OptimState.UNSCALED
+
+        combined_scale = loss_scale
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * self.loss_scale
+                combined_scale = clip * loss_scale

-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    p.grad.data.div_(combined_scale)
-        self.optim_state = OptimState.UNSCALED
+        if combined_scale == 1:
+            return -1
+        else:
+            return combined_scale

     @property
     def loss_scale(self):
@@ -199,12 +205,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
             self._update_fp16_params()
             return

-        # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_and_clip_grads()
+        # get combined scale. combined scale = loss scale * clipping norm
+        # so that gradient = gradient / combined scale
+        combined_scale = self._get_combined_scale()
         self.grad_scaler.update(found_inf)

-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
         self.zero_grad()
         self._update_fp16_params()
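
Reviewer note: the net effect of the hunks above is that the per-parameter p.grad.data.div_(combined_scale) loop disappears; loss scaling and gradient clipping are collapsed into a single combined scale that the fused optimizer applies while it updates, which is also why ZeroOptimizer now insists on one of FusedAdam, CPUAdam or HybridAdam (the optimizers whose step accepts div_scale). A standalone sketch of the new control flow (loss_scale, is_scaled, clipping, global_norm and max_norm stand in for the real attributes and _calc_global_norm):

def get_combined_scale(loss_scale, is_scaled, clipping, global_norm, max_norm):
    ls = loss_scale if is_scaled else 1
    combined = ls
    if clipping:
        clip = ((global_norm / ls) + 1e-6) / max_norm
        if clip > 1:                               # shrink the update only when the norm exceeds max_norm
            combined = clip * ls
    return -1 if combined == 1 else combined       # -1 tells the fused kernels there is nothing to divide by

# step() then reduces to a single fused call instead of one div_ per gradient tensor:
#     combined_scale = self._get_combined_scale()
#     ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)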

Changed file 8 of 8: fused Adam kernel test

@@ -71,7 +71,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
     weight_decay = 0

     multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
-                         True, weight_decay)
+                         True, weight_decay, -1)

     torch_adam_update(
         step,