[optimizer] add div_scale for optimizers (#2117)

* [optimizer] add div_scale for optimizers

* [zero] use div_scale in zero optimizer

* fix testing error
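
For orientation, here is a rough PyTorch sketch (illustrative only, not part of this diff; the function name adam_step_reference is invented) of what the new div_scale argument does inside the fused Adam kernel: the default -1 leaves gradients untouched, while any positive value divides each gradient before the Adam math, which is how the ZeRO optimizer folds loss-scale removal and gradient clipping into the kernel.

import torch

def adam_step_reference(p, g, m, v, lr, beta1, beta2, eps, step, weight_decay, div_scale=-1.0):
    # Mirrors AdamFunctor in ADAM_MODE_0 (L2 regularization); the AdamW mode differs slightly.
    if div_scale > 0:
        g = g / div_scale                              # in-kernel unscale / clip
    g = g + weight_decay * p                           # L2 term
    m.mul_(beta1).add_(g, alpha=1 - beta1)             # first moment
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)      # second moment
    m_hat = m / (1 - beta1**step)                      # bias correction
    v_hat = v / (1 - beta2**step)
    p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)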
HELSON 2022-12-12 17:58:57 +08:00 committed by GitHub
parent e5aa8333e4
commit e7d3afc9cc
8 changed files with 41 additions and 32 deletions

View File

@@ -11,7 +11,7 @@ def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List
     ...

-def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None:
+def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
     ...

View File

@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay);
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale);

 void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
@@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Computes and apply update for LAMB optimizer");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
 }

View File

@@ -28,7 +28,7 @@ struct AdamFunctor {
       int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
       const float beta1, const float beta2, const float beta1_correction,
       const float beta2_correction, const float epsilon, const float lr,
-      adamMode_t mode, const float decay) {
+      adamMode_t mode, const float decay, const float div_scale) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -79,6 +79,8 @@ struct AdamFunctor {
       }
 #pragma unroll
       for (int ii = 0; ii < ILP; ii++) {
+        if (div_scale > 0) r_g[ii] /= div_scale;
+
         if (mode == ADAM_MODE_0) { // L2
           r_g[ii] = r_g[ii] + (decay * r_p[ii]);
           r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay) {
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale) {
   using namespace at;

   // Handle bias correction mode
@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
                             beta2, bias_correction1, bias_correction2, epsilon,
-                            lr, (adamMode_t)mode, weight_decay);)
+                            lr, (adamMode_t)mode, weight_decay, div_scale);)
   AT_CUDA_CHECK(cudaGetLastError());
 }

View File

@@ -117,7 +117,7 @@ class CPUAdam(NVMeOptimizer):
                 data.addcdiv_(exp_avg, denom, value=-step_size)

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -152,9 +152,10 @@ class CPUAdam(NVMeOptimizer):
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                           group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], -1)
+                                          state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')
                 elif target_device.type == 'cuda':
+                    assert div_scale == -1, "div_scale should remain default"
                     assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                     assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"

View File

@@ -81,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
         else:
             super(FusedAdam, self).zero_grad()

-    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
         """Performs a single optimization step.

         Arguments:
@@ -137,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):
                 multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                      beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
-                                     group['weight_decay'])
+                                     group['weight_decay'], div_scale)

         return loss
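
A hedged usage sketch for the new keyword (the tiny model and loss scale below are invented for illustration, and the snippet assumes the ColossalAI CUDA extension is built and a GPU is available):

import torch
from colossalai.nn.optimizer import FusedAdam

model = torch.nn.Linear(4, 4).cuda().half()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

loss_scale = 128.0
out = model(torch.randn(2, 4, device='cuda', dtype=torch.float16))
(out.sum() * loss_scale).backward()                # gradients carry the loss scale
optimizer.step(div_scale=loss_scale)               # kernel divides grads by loss_scale before the update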

View File

@@ -89,7 +89,7 @@ class HybridAdam(NVMeOptimizer):
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -126,7 +126,7 @@ class HybridAdam(NVMeOptimizer):
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
                                           group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], -1)
+                                          state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')

                 elif target_device.type == 'cuda':
@@ -146,6 +146,6 @@ class HybridAdam(NVMeOptimizer):
             bias_correction = 1 if group['bias_correction'] else 0
             multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                  group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
                                  bias_correction, group['weight_decay'], div_scale)

         self._post_step()
         return loss

View File

@@ -10,10 +10,12 @@ from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.nn.parallel.data_parallel import ZeroDDP
 from colossalai.utils import disposable, get_current_device

+_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
+

 class OptimState(Enum):
     SCALED = 0
@@ -62,6 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
+        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
@@ -162,21 +165,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
         global_norm = math.sqrt(norm_sqr)
         return global_norm

-    def _unscale_and_clip_grads(self):
-        assert self.optim_state == OptimState.SCALED
+    def _get_combined_scale(self):
+        loss_scale = 1
+
+        if self.optim_state == OptimState.SCALED:
+            loss_scale = self.loss_scale
+            self.optim_state = OptimState.UNSCALED

-        combined_scale = self.loss_scale
+        combined_scale = loss_scale
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * self.loss_scale
+                combined_scale = clip * loss_scale

-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    p.grad.data.div_(combined_scale)
-        self.optim_state = OptimState.UNSCALED
+        if combined_scale == 1:
+            return -1
+        else:
+            return combined_scale

     @property
     def loss_scale(self):
@@ -199,12 +205,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
             self._update_fp16_params()
             return

-        # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_and_clip_grads()
+        # get combined scale. combined scale = loss scale * clipping norm
+        # so that gradient = gradient / combined scale
+        combined_scale = self._get_combined_scale()
         self.grad_scaler.update(found_inf)

-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
         self.zero_grad()
         self._update_fp16_params()
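
A small worked example (numbers invented) of what _get_combined_scale returns and why passing it as div_scale both unscales and clips in a single division:

loss_scale = 1024.0                                    # dynamic loss scale applied in backward
max_norm = 1.0                                         # clipping threshold
total_norm = 2048.0                                    # norm of the still-scaled gradients
clip = (total_norm / loss_scale + 1e-6) / max_norm     # ~2.0 > 1, so clipping is active
combined_scale = clip * loss_scale if clip > 1 else loss_scale   # ~2048.0
# grad / combined_scale == (grad / loss_scale) / clip: one division removes the loss
# scale and brings the global norm down to max_norm. If no unscaling or clipping is
# needed, the method returns -1 so the fused kernel skips the division entirely.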

View File

@@ -71,7 +71,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
     weight_decay = 0

     multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
-                         True, weight_decay)
+                         True, weight_decay, -1)

     torch_adam_update(
         step,