From e4f555f29a063e31fc9247c786a154cfd524f043 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 20 Jun 2022 11:19:38 +0800
Subject: [PATCH] [optim] refactor fused sgd (#1134)

---
 .../csrc/multi_tensor_sgd_kernel.cu  | 71 +++-----------
 colossalai/nn/optimizer/fused_sgd.py | 95 ++++---------------
 2 files changed, 31 insertions(+), 135 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
index a077bc738..35f2c9b4e 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
@@ -28,10 +28,10 @@
  * first run : necessary for proper momentum handling & init
  * wd_after_momentum : apply weight decay _after_ momentum instead of before
  **/
-template <int N, typename T_grad, typename T_weight>
+template <typename T_grad, typename T_weight>
 struct SGDFunctor {
   __device__ __forceinline__ void operator()(
-      int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
       float wd, float momentum, float dampening, float lr, bool nesterov,
       bool first_run, bool wd_after_momentum, float scale) {
     // Early exit if we don't need to do anything
@@ -50,12 +50,6 @@ struct SGDFunctor {
     T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
     mom_in += chunk_idx * chunk_size;
 
-    at::Half *model_weights_out = nullptr;
-    if (N == 4) {
-      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
-      model_weights_out += chunk_idx * chunk_size;
-    }
-
     n -= chunk_idx * chunk_size;
 
     // Non-divergent exit condition for the __syncthreads
@@ -110,10 +104,6 @@ struct SGDFunctor {
           // adjust the weight and write out
           weight_in[i] += (-lr * incoming_grads[ii]);
 
-          // if necessary, write out an fp16 copy of the weights
-          if (N == 4)
-            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
-
           // also write out the new momentum
           if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
         }
@@ -131,20 +121,14 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
   auto grad_type = tensor_lists[0][0].scalar_type();
   auto weight_type = tensor_lists[1][0].scalar_type();
 
-  if (num_tensors == 4)
-    for (int i = 0; i < tensor_lists[3].size(); i++)
-      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
-                  "Additional output tensors should always be fp16.");
-
   TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
               "expected noop flag to be on the same device as tensors");
 
   // We have 3 possibilities to handle here, in terms of
-  // grad_type, param_type, momentum_type, requires_fp16_copy
-  // 1. fp16, fp16, fp16, No
-  // 2. fp32, fp32, fp32, No
-  // 3. fp16, fp32, fp32, Yes
-  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+  // grad_type, param_type, momentum_type
+  // 1. fp16, fp16, fp16
+  // 2. fp32, fp32, fp32
+  // 3. fp16, fp32, fp32
   // It's easier to hardcode these possibilities than to use
   // switches etc. to handle the cross-product of cases where
   // we don't want the majority of them.
@@ -153,49 +137,22 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
   if (grad_type == at::ScalarType::Half &&
       weight_type == at::ScalarType::Half && num_tensors == 3) {
     multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                          SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
+                          SGDFunctor<at::Half, at::Half>(), wd, momentum,
                           dampening, lr, nesterov, first_run, wd_after_momentum,
                           scale);
   }
-  // Case 2. fp16, fp32, fp32, No
-  // else if (grad_type == at::ScalarType::Half &&
-  //          weight_type == at::ScalarType::Float &&
-  //          num_tensors == 3) {
-  //   multi_tensor_apply<3>(
-  //       BLOCK_SIZE,
-  //       chunk_size,
-  //       noop_flag,
-  //       tensor_lists,
-  //       SGDFunctor<3, at::Half, float>(),
-  //       wd,
-  //       momentum,
-  //       dampening,
-  //       lr,
-  //       nesterov,
-  //       first_run,
-  //       wd_after_momentum);
-  // }
-  // Case 2. fp32, fp32, fp32, No
+  // Case 2. fp32, fp32, fp32
   else if (grad_type == at::ScalarType::Float &&
            weight_type == at::ScalarType::Float && num_tensors == 3) {
     multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                          SGDFunctor<3, float, float>(), wd, momentum,
-                          dampening, lr, nesterov, first_run, wd_after_momentum,
-                          scale);
+                          SGDFunctor<float, float>(), wd, momentum, dampening,
+                          lr, nesterov, first_run, wd_after_momentum, scale);
   }
-  // Case 3. fp16, fp32, fp32, Yes
+  // Case 3. fp16, fp32, fp32
   else if (grad_type == at::ScalarType::Half &&
-           weight_type == at::ScalarType::Float && num_tensors == 4) {
-    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                          SGDFunctor<4, at::Half, float>(), wd, momentum,
-                          dampening, lr, nesterov, first_run, wd_after_momentum,
-                          scale);
-  }
-  // Case 4. fp32, fp32, fp32, Yes
-  else if (grad_type == at::ScalarType::Float &&
-           weight_type == at::ScalarType::Float && num_tensors == 4) {
-    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                          SGDFunctor<4, float, float>(), wd, momentum,
+           weight_type == at::ScalarType::Float && num_tensors == 3) {
+    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                          SGDFunctor<at::Half, float>(), wd, momentum,
                           dampening, lr, nesterov, first_run, wd_after_momentum,
                           scale);
   } else {
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 6149332ec..b948c5eaf 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -64,9 +64,7 @@ class FusedSGD(Optimizer):
                  dampening=0,
                  weight_decay=0,
                  nesterov=False,
-                 wd_after_momentum=False,
-                 materialize_master_grads=True,
-                 set_grad_none=False):
+                 wd_after_momentum=False):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
@@ -80,10 +78,6 @@ class FusedSGD(Optimizer):
         super(FusedSGD, self).__init__(params, defaults)
 
         self.wd_after_momentum = wd_after_momentum
-        self.materialize_master_grads = materialize_master_grads
-        self.most_recent_scale = 1.0
-        self.scale_set_by_backward = False
-        self.set_grad_none = set_grad_none
 
         if multi_tensor_applier.available:
             import colossal_C
@@ -100,14 +94,6 @@ class FusedSGD(Optimizer):
         for group in self.param_groups:
             group.setdefault('nesterov', False)
 
-    def zero_grad(self):
-        if self.set_grad_none:
-            for group in self.param_groups:
-                for p in group['params']:
-                    p.grad = None
-        else:
-            super(FusedSGD, self).zero_grad()
-
     def get_momentums(self, params):
         momentums = []
         first_run = True
@@ -136,74 +122,27 @@ class FusedSGD(Optimizer):
         if closure is not None:
             loss = closure()
 
-        explicit_master_params = (hasattr(self, "_amp_stash") and hasattr(self._amp_stash, "fp32_from_fp16_groups"))
-
-        for gid, group in enumerate(self.param_groups):
+        for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
             dampening = group['dampening']
             nesterov = group['nesterov']
 
             # For each group, there are 3 possible combinations we need to consider:
-            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
-            # 1. fp16, fp16, fp16, No
-            # 2. fp32, fp32, fp32, No
-            # 3. fp16, fp32, fp32, Yes
-
-            first_runs = [True, True]
-
-            # I think a bit of code divergence in exchange for naming clarity is worthwhile
-            if explicit_master_params:
-                stash = self._amp_stash
-
-                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
-                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
-                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
-
-                if self.materialize_master_grads:
-                    fp16_model_params = [
-                        p for i, p in enumerate(stash.fp16_groups[gid])
-                        if stash.fp32_from_fp16_groups[gid][i].grad is not None
-                    ]
-                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
-                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
-                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
-
-                    fp16_set = [
-                        fp32_from_fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params
-                    ]
-                else:
-                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
-                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
-                    fp32_from_fp16_params = [
-                        p for i, p in enumerate(stash.fp32_from_fp16_groups[gid])
-                        if stash.fp16_groups[gid][i].grad is not None
-                    ]
-                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
-
-                    fp16_set = [fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params]
-
-                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
-            else:
-                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
-                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
-                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
-
-                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
-                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
-                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
-
-                launch_sets = [[fp16_grads, fp16_params, fp16_momentums], [fp32_grads, fp32_params, fp32_momentums]]
-
-            for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
-                assert len(launch_set[0]) == len(launch_set[1])
-                assert len(launch_set[0]) == len(launch_set[2])
-                if len(launch_set[0]) > 0:
-                    multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, launch_set, weight_decay,
-                                         momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum,
-                                         1.0 / self.most_recent_scale)
-
-        self.most_recent_scale = 1.0
-        self.scale_set_by_backward = False
+            # grad_type, param_to_update_type, momentum_type
+            # 1. fp16, fp16, fp16
+            # 2. fp32, fp32, fp32
+            # 3. fp16, fp32, fp32
+            g_l, p_l = [], []
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if p.grad.data.is_sparse:
+                    raise RuntimeError('FusedSGD does not support sparse gradients')
+                g_l.append(p.grad)
+                p_l.append(p)
+            m_l, first_run = self.get_momentums(p_l)
+            multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay,
+                                 momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0)
 
         return loss
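
For reference, a minimal usage sketch of the refactored optimizer follows (assuming ColossalAI is installed on a CUDA machine with the colossal_C extension built; the toy model, shapes, and variable names are illustrative only and not part of this patch):

# Minimal sketch, not part of the patch: exercises the fp16/fp16/fp16 path
# (Case 1) of the refactored FusedSGD. Assumes a CUDA device and a ColossalAI
# install with the colossal_C multi-tensor extension compiled.
import torch
from colossalai.nn.optimizer import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda().half()
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

x = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
loss = model(x).float().sum()
loss.backward()

# step() now flattens each param group into three aligned lists
# [grads, params, momentums] and issues a single multi-tensor kernel launch
# per group; the apex-style master-weight / fp16-copy bookkeeping is gone.
optimizer.step()
optimizer.zero_grad()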