diff --git a/colossalai/_C/fused_optim.pyi b/colossalai/_C/fused_optim.pyi
index 6d8e97dd9..983b02335 100644
--- a/colossalai/_C/fused_optim.pyi
+++ b/colossalai/_C/fused_optim.pyi
@@ -11,7 +11,7 @@ def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List
     ...
 
 
-def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None:
+def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
     ...
 
 
diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp
index a687adc7b..94f132521 100644
--- a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp
+++ b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp
@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay);
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale);
 
 void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
@@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Computes and apply update for LAMB optimizer");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
-}
\ No newline at end of file
+}
diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
index 891f23e4e..afd34bb96 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
@@ -28,7 +28,7 @@ struct AdamFunctor {
       int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
       const float beta1, const float beta2, const float beta1_correction,
       const float beta2_correction, const float epsilon, const float lr,
-      adamMode_t mode, const float decay) {
+      adamMode_t mode, const float decay, const float div_scale) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -79,6 +79,8 @@ struct AdamFunctor {
       }
 #pragma unroll
       for (int ii = 0; ii < ILP; ii++) {
+        if (div_scale > 0) r_g[ii] /= div_scale;
+
         if (mode == ADAM_MODE_0) {  // L2
           r_g[ii] = r_g[ii] + (decay * r_p[ii]);
           r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay) {
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale) {
   using namespace at;
 
   // Handle bias correction mode
@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdamFunctor<scalar_t_0>(), beta1, beta2,
                             bias_correction1, bias_correction2, epsilon,
-                            lr, (adamMode_t)mode, weight_decay);)
+                            lr, (adamMode_t)mode, weight_decay, div_scale);)
 
   AT_CUDA_CHECK(cudaGetLastError());
-}
\ No newline at end of file
+}
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 745d8de22..5b05fecc8 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -117,7 +117,7 @@ class CPUAdam(NVMeOptimizer):
         data.addcdiv_(exp_avg, denom, value=-step_size)
 
     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -152,9 +152,10 @@ class CPUAdam(NVMeOptimizer):
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
                                           group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], -1)
+                                          state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')
 
                 elif target_device.type == 'cuda':
+                    assert div_scale == -1, "div_scale should remain default"
                     assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                     assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 4687e6f3b..064e55a40 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -81,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
         else:
             super(FusedAdam, self).zero_grad()
 
-    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
         """Performs a single optimization step.
 
         Arguments:
@@ -137,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):
 
             multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                  beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
-                                 group['weight_decay'])
+                                 group['weight_decay'], div_scale)
 
         return loss
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 676dc71e4..a925c3d91 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -89,7 +89,7 @@ class HybridAdam(NVMeOptimizer):
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])
 
     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None, div_scale: float = -1):
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -126,7 +126,7 @@ class HybridAdam(NVMeOptimizer):
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
                                           group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], -1)
+                                          state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')
 
                 elif target_device.type == 'cuda':
@@ -146,6 +146,6 @@ class HybridAdam(NVMeOptimizer):
                 bias_correction = 1 if group['bias_correction'] else 0
                 multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
                                      group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
-                                     bias_correction, group['weight_decay'])
+                                     bias_correction, group['weight_decay'], div_scale)
         self._post_step()
         return loss
diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 62a0be329..2786d4496 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -10,10 +10,12 @@ from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.nn.parallel.data_parallel import ZeroDDP
 from colossalai.utils import disposable, get_current_device
 
+_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
+
 
 class OptimState(Enum):
     SCALED = 0
@@ -62,6 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
+        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use an optimizer in the available list"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
@@ -162,21 +165,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
         global_norm = math.sqrt(norm_sqr)
         return global_norm
 
-    def _unscale_and_clip_grads(self):
-        assert self.optim_state == OptimState.SCALED
+    def _get_combined_scale(self):
+        loss_scale = 1
 
-        combined_scale = self.loss_scale
+        if self.optim_state == OptimState.SCALED:
+            loss_scale = self.loss_scale
+            self.optim_state = OptimState.UNSCALED
+
+        combined_scale = loss_scale
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * self.loss_scale
+                combined_scale = clip * loss_scale
 
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    p.grad.data.div_(combined_scale)
-        self.optim_state = OptimState.UNSCALED
+        if combined_scale == 1:
+            return -1
+        else:
+            return combined_scale
 
     @property
     def loss_scale(self):
@@ -199,12 +205,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
             self._update_fp16_params()
             return
 
-        # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_and_clip_grads()
+        # get the combined scale (loss scale * clip factor),
+        # so that gradient = gradient / combined scale
+        combined_scale = self._get_combined_scale()
 
         self.grad_scaler.update(found_inf)
-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
         self.zero_grad()
         self._update_fp16_params()
diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py
index 2291b0ce6..d95a23702 100644
--- a/tests/test_optimizer/test_fused_adam_kernel.py
+++ b/tests/test_optimizer/test_fused_adam_kernel.py
@@ -71,7 +71,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
     weight_decay = 0
 
     multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
-                         True, weight_decay)
+                         True, weight_decay, -1)
 
     torch_adam_update(
         step,