From e79ea44247bbb6457cd2a2c30454208017fd31ef Mon Sep 17 00:00:00 2001
From: Frank Lee <somerlee.9@gmail.com>
Date: Tue, 15 Mar 2022 10:05:38 +0800
Subject: [PATCH] [fp16] refactored fp16 optimizer (#392)

---
 colossalai/amp/naive_amp/__init__.py          |  21 +-
 colossalai/amp/naive_amp/_fp16_optimizer.py   | 473 +++++++-----------
 colossalai/amp/naive_amp/_utils.py            |  40 ++
 .../naive_amp/grad_scaler/base_grad_scaler.py |   2 -
 .../grad_scaler/dynamic_grad_scaler.py        |  14 +-
 colossalai/amp/naive_amp/naive_amp.py         |  10 +-
 colossalai/initialize.py                      |   2 +-
 .../zero/sharded_optim/sharded_optim.py       |  69 ++-
 .../zero/sharded_optim/sharded_optim_v2.py    |   2 +-
 tests/test_amp/test_naive_fp16.py             |  83 +++
 10 files changed, 371 insertions(+), 345 deletions(-)
 create mode 100644 colossalai/amp/naive_amp/_utils.py
 create mode 100644 tests/test_amp/test_naive_fp16.py

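The refactor replaces the DynamicGradScaler that was embedded in FP16Optimizer
with pluggable grad scaler classes, selected in convert_to_naive_amp via the
`dynamic_grad_scale` flag. A minimal configuration sketch, assuming the usual
config-driven colossalai setup; the key names come from this patch and the
existing grad scaler module, while the surrounding field layout is illustrative:

    from colossalai.amp import AMP_TYPE

    fp16 = dict(
        mode=AMP_TYPE.NAIVE,        # routed through convert_to_naive_amp
        dynamic_grad_scale=True,    # False selects ConstantGradScaler instead
        initial_scale=2**16,        # forwarded to DynamicGradScaler.__init__
    )
    clip_grad_norm = 1.0            # becomes FP16Optimizer's clip_grad_norm
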
diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py
index 32ea3469a..2390c199e 100644
--- a/colossalai/amp/naive_amp/__init__.py
+++ b/colossalai/amp/naive_amp/__init__.py
@@ -1,13 +1,12 @@
+import inspect
 import torch.nn as nn
 from torch.optim import Optimizer
 from colossalai.utils import is_no_pp_or_last_stage
-
 from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+from .grad_scaler import DynamicGradScaler, ConstantGradScaler
 
 
-def convert_to_naive_amp(model: nn.Module,
-                         optimizer: Optimizer,
-                         amp_config):
+def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     """A helper function to wrap training components with naive AMP modules
 
     :param model: your model object
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
         output_to_fp32 = is_no_pp_or_last_stage()
         model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
 
-    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
+    if use_dynamic_grad_scaler:
+        scaler_class = DynamicGradScaler
+    else:
+        scaler_class = ConstantGradScaler
+
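+    # keep only the kwargs accepted by the chosen scaler's constructor;
+    # whatever remains in amp_config (e.g. clip_grad_norm) is forwarded to
+    # NaiveAMPOptimizer below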
+    sig = inspect.signature(scaler_class.__init__)
+    kwargs = dict()
+    for param in sig.parameters.values():
+        if param.name in amp_config:
+            kwargs[param.name] = amp_config.pop(param.name)
+    grad_scaler = scaler_class(**kwargs)
+    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
     return model, optimizer
 
 
diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py
index 01842590f..98bb1e639 100644
--- a/colossalai/amp/naive_amp/_fp16_optimizer.py
+++ b/colossalai/amp/naive_amp/_fp16_optimizer.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import torch
+import torch.distributed as dist
 
 try:
     import colossal_C
@@ -9,41 +10,30 @@ except:
     print('Colossalai should be built with cuda extension to use the FP16 optimizer')
 
 from torch.optim import Optimizer
-
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
-from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier)
+from torch.distributed import ProcessGroup
+from .grad_scaler import BaseGradScaler
+from ._utils import has_inf_or_nan, zero_grad_by_list
 
-
-def _zero_grad_group_helper(group, set_to_none):
-    """Zero out the gradient for a group of parameters.
-    Note: copied from torch.optim.optimizer."""
-    for param in group:
-        if param.grad is not None:
-            if set_to_none:
-                param.grad = None
-            else:
-                if param.grad.grad_fn is not None:
-                    param.grad.detach_()
-                else:
-                    param.grad.requires_grad_(False)
-                param.grad.zero_()
+__all__ = ['FP16Optimizer']
 
 
 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another.
+    """
+    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
+
+    Use multi-tensor-applier to copy values from one list to another.
     We don't have a bfloat16 implementation, so for now, if the overflow_buf
     is not provided, we fall back to a simple loop copy to stay compatible
-    with bfloat16."""
+    with bfloat16.
+    """
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(colossal_C.multi_tensor_scale,
-                             overflow_buf,
-                             [this, that],
-                             1.0)
+        multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
     else:
         for this_, that_ in zip(this, that):
             that_.copy_(this_)
@@ -111,8 +101,7 @@ class DynamicGradScaler:
             self._hysteresis_tracker -= 1
             # Now if we are out of hysteresis count, scale down the loss.
             if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor,
-                                        self.min_scale)
+                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
             if self.verbose:
                 self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
@@ -127,12 +116,13 @@ class DynamicGradScaler:
                 if self._max_scale is not None and self._scale >= self._max_scale:
                     if self.verbose:
                         self._logger.info(
-                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
+                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
+                            ranks=[0])
                 else:
                     self._scale = self._scale * self.growth_factor
                     if self.verbose:
-                        self._logger.info(
-                            f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
+                        self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}',
+                                          ranks=[0])
 
     def state_dict(self):
         state_dict = {}
@@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer):
     """
 
     def __init__(self,
-                 optimizer,
-                 clip_grad=0,
-                 log_num_zeros_in_grad=False,
-                 initial_scale=2 ** 32,
-                 min_scale=1,
-                 growth_factor=2,
-                 backoff_factor=0.5,
-                 growth_interval=1000,
-                 hysteresis=2,
-                 max_scale: int = 2 ** 32,
-                 verbose: bool = False):
-        # default args for compatibility
-        bf16 = False
-        params_have_main_grad = False
-
+                 optimizer: Optimizer,
+                 grad_scaler: BaseGradScaler,
+                 verbose: bool = False,
+                 clip_grad_norm=0,
+                 dp_process_group: ProcessGroup = None,
+                 mp_process_group: ProcessGroup = None):
         # keep defaults for compatibility with pytorch optim
-        self.defaults = optimizer.defaults
+        self._optimizer = optimizer
+        self._defaults = optimizer.defaults
 
-        # log config
-        self._logger = get_dist_logger()
-        if verbose:
-            self._logger.info(f"\n=========  FP16 Optimizer Config =========\n"
-                              f"Optimizer: {optimizer.__class__.__name__}\n"
-                              f"clip_grad = {clip_grad}\n"
-                              f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
-                              f"initial_scale = {initial_scale}\n"
-                              f"min_scale = {min_scale}\n"
-                              f"growth_factor = {growth_factor}\n"
-                              f"backoff_factor = {backoff_factor}\n"
-                              f"growth_interval = {growth_interval}\n"
-                              f"hysteresis = {hysteresis}\n"
-                              f"==========================================", ranks=[0])
+        # fp16-related params
+        assert isinstance(grad_scaler, BaseGradScaler)
+        self._grad_scaler = grad_scaler
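+        # _found_overflow is all-reduced across the dp/mp process groups in
+        # _check_overflow so that every rank agrees on whether to skip the step;
+        # _dummy_overflow_buf is required by the fused multi-tensor copy kernel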
+        self._found_overflow = torch.cuda.FloatTensor([0.0])
+        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
 
-        """Input optimizer is the base optimizer for example Adam."""
-        self.optimizer = optimizer
-        assert self.optimizer, 'no optimizer is provided.'
-        # Set gradient clipping and logging params.
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
-        self.params_have_main_grad = params_have_main_grad
+        # misc params
+        self._clip_grad_max_norm = clip_grad_norm
 
-        self.bf16 = bf16
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
-            verbose=verbose
-        )
+        # get process group
+        def _get_process_group(parallel_mode):
+            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
+                return gpc.get_group(parallel_mode)
+            else:
+                return None
 
-        # None grad scaler is only supported for bf16.
-        if self.grad_scaler is None:
-            assert self.bf16, 'fp16 expects a grad scaler.'
+        if dp_process_group is None:
+            dp_process_group = _get_process_group(ParallelMode.DATA)
+        if mp_process_group is None:
+            mp_process_group = _get_process_group(ParallelMode.MODEL)
 
-        # Tensor used to determine if a nan/if has happend.
-        # Any non-zero value indicates inf/nan.
-        # Note that we keep this for the cases that grad scaler is none.
-        # We still record nan/inf if we have a bfloat16 with a grad scaler.
-        if self.grad_scaler:
-            self.found_inf = torch.cuda.FloatTensor([0.0])
+        self._dp_process_group = dp_process_group
+        self._mp_process_group = mp_process_group
 
-        # Dummy tensor needed for apex multi-apply tensor.
-        # For bfloat, we don't have multi-tensor apply and for now
-        # we set it to none so the multi-tensor apply gets ignored.
-        if bf16:
-            self._dummy_overflow_buf = None
-        else:
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-
-        # In case grad scaler is not passed, define the unity scale.
-        if self.grad_scaler is None:
-            self._scale_one = torch.cuda.FloatTensor([1.0])
-
-        # ======================
-        # main parameter stuff
-        # ======================
-
-        # Three groups of parameters:
-        #   float16_groups: original float16 parameters
-        #   fp32_from_float16_groups: fp32 copy of float16 parameters
-        #   fp32_from_fp32_groups: original fp32 parameters
-        self.float16_groups = []
-        self.fp32_from_float16_groups = []
-        self.fp32_from_fp32_groups = []
+        # we maintain three groups of parameters
+        # so that the model can have a mixture
+        # of fp16 and fp32 params
+        # fp16_param_groups: the fp16 params of the model
+        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
+        # fp32_param_groups: the fp32 params of the model
+        # NOTE:
+        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
+        # 2. fp32_param_groups and fp16_param_groups are mutually exclusive
+        self._fp16_param_groups = []
+        self._fp32_master_param_groups = []
+        self._fp32_param_groups = []
 
         # For all the groups in the original optimizer:
-        for param_group in self.optimizer.param_groups:
-            float16_params_this_group = []
-            fp32_params_this_group = []
-            fp32_from_float16_params_this_group = []
+        for param_group in self._optimizer.param_groups:
+            fp16_params = []
+            fp32_master_params = []
+            fp32_params = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
                     # float16 params:
-                    if param.type() in ['torch.cuda.HalfTensor',
-                                        'torch.cuda.BFloat16Tensor']:
-                        float16_params_this_group.append(param)
-                        # Create a copy
-                        main_param = param.detach().clone().float()
-                        # Copy tensor model parallel attributes.
-                        copy_tensor_parallel_attributes(param, main_param)
+                    if param.type() in ['torch.cuda.HalfTensor']:
+                        fp16_params.append(param)
 
-                        # if hasattr(param, 'shared'):
-                        #     main_param.shared = param.shared
+                        # Create a fp32 copy
+                        fp32_param = param.detach().clone().float()
+                        # Copy tensor model parallel attributes.
+                        copy_tensor_parallel_attributes(param, fp32_param)
 
                         # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = main_param
-                        fp32_from_float16_params_this_group.append(main_param)
+                        param_group['params'][i] = fp32_param
+                        fp32_master_params.append(fp32_param)
+
                         # Reset existing state dict key to the new main param.
-                        if param in self.optimizer.state:
-                            self.optimizer.state[main_param] \
-                                = self.optimizer.state.pop(param)
+                        if param in self._optimizer.state:
+                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
 
                     # fp32 params.
                     elif param.type() == 'torch.cuda.FloatTensor':
-                        fp32_params_this_group.append(param)
-                        param_group['params'][i] = param
+                        fp32_params.append(param)
                     else:
-                        raise TypeError('Wrapped parameters must be one of '
-                                        'torch.cuda.FloatTensor,  '
-                                        'torch.cuda.HalfTensor, or '
-                                        'torch.cuda.BFloat16Tensor. '
-                                        'Received {}'.format(param.type()))
+                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
+                                        f'or torch.cuda.HalfTensor, but got {param.type()}')
 
-            self.float16_groups.append(float16_params_this_group)
-            self.fp32_from_float16_groups.append(
-                fp32_from_float16_params_this_group)
-            self.fp32_from_fp32_groups.append(fp32_params_this_group)
+            self._fp16_param_groups.append(fp16_params)
+            self._fp32_master_param_groups.append(fp32_master_params)
+            self._fp32_param_groups.append(fp32_params)
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        self._optimizer.load_state_dict(self._optimizer.state_dict())
 
-    def zero_grad(self, set_to_none=False):
-        """We only need to zero the model related parameters, i.e.,
-                float16_groups & fp32_from_fp32_groups."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
+        # log config
+        self._logger = get_dist_logger()
+        if verbose:
+            self._logger.info(
+                f"\n=========  FP16 Optimizer Config =========\n"
+                f"Optimizer: {optimizer.__class__.__name__}\n"
+                f"clip_grad_norm = {clip_grad_norm}\n"
+                f"grad_scaler = {self._grad_scaler.__class__.__name__}"
+                f"==========================================",
+                ranks=[0])
 
-    def get_loss_scale(self):
-        if self.grad_scaler is None:
-            return self._scale_one
-        return self.grad_scaler.scale
+    @property
+    def grad_scaler(self):
+        return self._grad_scaler
 
-    def _copy_model_grads_to_main_grads(self):
+    @property
+    def loss_scale(self):
+        return self._grad_scaler.scale
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def defaults(self):
+        return self._defaults
+
+    def _check_overflow(self):
+        # clear previous overflow record
+        self._found_overflow.fill_(0.0)
+
+        # check for overflow
+        for group in self._optimizer.param_groups:
+            for p in group['params']:
+                if p.grad is not None and has_inf_or_nan(p.grad):
+                    self._found_overflow.fill_(1.0)
+                    break
+
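+        # a MAX reduction lets any single rank that saw inf/nan flag the
+        # step as overflowed on every rank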
+        # all-reduce across dp group
+        if self._dp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)
+
+        # all-reduce over model parallel group
+        if self._mp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)
+
+        return self._found_overflow.item() > 0
+
+    def zero_grad(self, set_to_none=True):
+        # set_to_none = True can save some memory space
+        for param_group in self._optimizer.param_groups:
+            zero_grad_by_list(param_group['params'], set_to_none=set_to_none)
+
+    def _get_fp32_param_groups_to_update(self):
+        return self._fp32_master_param_groups + self._fp32_param_groups
+
+    def _unscale_grads(self):
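+        # gradients were computed on the scaled loss (see backward below),
+        # so divide them by the current loss scale before clipping and stepping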
+        for group in self._get_fp32_param_groups_to_update():
+            for p in group:
+                if p.grad is not None:
+                    p.grad.data.div_(self.loss_scale)
+
+    def _assign_grad_to_fp32_master_param(self):
         # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad:
-                    main_param.grad = model_param.main_grad.float()
-                else:
-                    if model_param.grad is not None:
-                        main_param.grad = model_param.grad.float()
+        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
+                if fp16_param.grad is not None:
+                    fp32_param.grad = fp16_param.grad.float()
+                    # clear the unneeded grad on the fp16 param to free memory
+                    fp16_param.grad = None
 
-        # For fp32 grads, we need to reset the grads to main grad.
-        if self.params_have_main_grad:
-            for model_group in self.fp32_from_fp32_groups:
-                for model_param in model_group:
-                    model_param.grad = model_param.main_grad
-
-    def _unscale_main_grads_and_check_for_nan(self):
-        main_grads = []
-        # fp32 params fromm float16 ones.
-        for main_group in self.fp32_from_float16_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Append fp32 parameters.
-        for main_group in self.fp32_from_fp32_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Reset found inf.
-        self.found_inf.fill_(0.0)
-        # Unscale and set found inf/nan
-        torch._amp_foreach_non_finite_check_and_unscale_(
-            main_grads, self.found_inf, self.grad_scaler.inv_scale)
-        # Update across all model parallel instances.
-        torch.distributed.all_reduce(self.found_inf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.MODEL))
-
-        # Check for nan.
-        found_inf_flag = (self.found_inf.item() > 0)
-        return found_inf_flag
-
-    def _get_model_and_main_params_data_float16(self):
-        model_data = []
-        main_data = []
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                model_data.append(model_param.data)
-                main_data.append(main_param.data)
-        return model_data, main_data
-
-    def _copy_main_params_to_model_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
+    def _update_fp16_param_from_fp32_param(self):
+        fp16_param_data = []
+        fp32_master_param_data = []
+        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                fp16_param_data.append(fp16_param.data)
+                fp32_master_param_data.append(fp32_param.data)
+        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
+                                        that=fp16_param_data,
                                         overflow_buf=self._dummy_overflow_buf)
 
-    def _copy_model_params_to_main_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
-                                        overflow_buf=self._dummy_overflow_buf)
-
-    def reload_model_params(self):
-        self._copy_model_params_to_main_params()
-
-    @torch.no_grad()
     def step(self):
         # Copy gradients from model params to main params.
-        self._copy_model_grads_to_main_grads()
+        self._assign_grad_to_fp32_master_param()
+        self._unscale_grads()
 
-        # Do unscale, check for inf, and update grad scaler only for
-        # the case that grad scaler is provided.
-        if self.grad_scaler:
+        overflow = self._check_overflow()
+        self._grad_scaler.update(overflow)
 
-            # Unscale and check for inf/nan.
-            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-
-            # We are done with scaling gradients
-            # so we can update the loss scale.
-            self.grad_scaler.update(found_inf_flag)
-
-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
+        if overflow:
+            self.zero_grad()
+            return False, None
 
         # Clip the main gradients.
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-
-        # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
+        if self._clip_grad_max_norm > 0.0:
+            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)
 
         # Step the optimizer.
-        self.optimizer.step()
+        self._optimizer.step()
 
         # Update params from main params.
-        self._copy_main_params_to_model_params()
+        self._update_fp16_param_from_fp32_param()
 
         # Successful update.
-        return True, grad_norm, num_zeros_in_grad
+        return True, grad_norm
+
+    def backward(self, loss):
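+        # scale the loss so that small fp16 gradients do not underflow;
+        # the gradients are divided by the same scale again in step()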
+        scaled_loss = loss * self.grad_scaler.scale
+        scaled_loss.backward()
 
     def state_dict(self):
         state_dict = {}
-        state_dict['optimizer'] = self.optimizer.state_dict()
+        state_dict['optimizer'] = self._optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
+        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
         return state_dict
 
     def load_state_dict(self, state_dict):
         # Optimizer.
-        optimizer_key = 'optimizer'
-        if optimizer_key not in state_dict:
-            optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
-        self.optimizer.load_state_dict(state_dict[optimizer_key])
+        self._optimizer.load_state_dict(state_dict['optimizer'])
 
         # Grad scaler.
-        if 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
-        else:
-            if self.grad_scaler:
-                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-            else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
+        if 'grad_scaler' in state_dict:
+            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
 
         # Copy data for the main params.
-        fp32_from_float16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_float16_params_key not in state_dict:
-            fp32_from_float16_params_key = 'fp32_from_fp16'
-        for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
-            for current_param, saved_param in zip(current_group, saved_group):
-                current_param.data.copy_(saved_param.data)
-
-    def get_parameters(self):
-        params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                params.append(param)
-        return params
+        if 'fp32_master_param_groups' in state_dict:
+            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
+                                                 state_dict['fp32_master_param_groups']):
+                for current_param, ckpt_param in zip(current_group, ckpt_group):
+                    current_param.data.copy_(ckpt_param.data)
 
     def clip_grad_norm(self, clip_grad):
-        params = self.get_parameters()
+        params = []
+        for param_group in self._optimizer.param_groups:
+            for param in param_group['params']:
+                params.append(param)
         return clip_grad_norm_fp32(params, clip_grad)
 
-    def count_zeros(self):
-        params = self.get_parameters()
-        return count_zeros_fp32(params)
-
-    def scale_loss(self, loss):
-        """Simple scaling."""
-        return self.get_loss_scale() * loss
-
     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
-        return self.optimizer.state
+        return self._optimizer.state
 
     def _set_state(self, value):
-        self.optimizer.state = value
+        self._optimizer.state = value
 
     state = property(_get_state, _set_state)
 
@@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer):
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
     def _get_param_groups(self):
-        return self.optimizer.param_groups
+        return self._optimizer.param_groups
 
     def _set_param_groups(self, value):
-        self.optimizer.param_groups = value
+        self._optimizer.param_groups = value
 
     param_groups = property(_get_param_groups, _set_param_groups)
diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py
new file mode 100644
index 000000000..5d87135a8
--- /dev/null
+++ b/colossalai/amp/naive_amp/_utils.py
@@ -0,0 +1,40 @@
+from typing import List
+from torch import Tensor
+
+
+def has_inf_or_nan(tensor):
+    try:
+        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary
+        # in case Pytorch's .sum() creates a one-element tensor of the same type as tensor
+        # (which is true for some recent versions of pytorch)
+        tensor_sum = float(tensor.float().sum())
+        # More efficient version that can be used if .sum() returns a Python scalar
+        # tensor_sum = float(tensor.sum())
+    except RuntimeError as instance:
+        # We want to check if inst is actually an overflow exception.
+        # RuntimeError could come from a different error.
+        # If so, we still want the exception to propagate.
+        if "value cannot be converted" not in instance.args[0]:
+            raise
+        return True
+    else:
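+        # +/- inf is caught by the two equality checks below; NaN is caught by
+        # `tensor_sum != tensor_sum`, since NaN never compares equal to itself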
+        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
+            return True
+        return False
+
+
+def zero_grad_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
+    """
+    Clear the gradients of a list of tensors.
+    Note: copied from torch.optim.optimizer.
+    """
+    for param in tensor_list:
+        if param.grad is not None:
+            if set_to_none:
+                param.grad = None
+            else:
+                if param.grad.grad_fn is not None:
+                    param.grad.detach_()
+                else:
+                    param.grad.requires_grad_(False)
+                param.grad.zero_()
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index fb279baf6..2d3e3700d 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
     def inv_scale(self) -> Tensor:
         return self._scale.double().reciprocal().float()
 
-    @abstractmethod
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale
+        return state_dict
 
-    @abstractmethod
     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
 
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 79fd0f3a3..49f155f06 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
                  growth_interval: int = 1000,
                  min_scale: int = None,
                  max_scale: int = None,
-                 hysteresis: int = None,
+                 hysteresis: int = 2,
                  verbose: bool = False):
         super().__init__(initial_scale, verbose)
-        self._min_scale = min_scale
-        self._max_scale = max_scale
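+        # store the optional bounds as cuda tensors so they can be used in
+        # tensor ops against the running scale directly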
+        if min_scale:
+            self._min_scale = torch.cuda.FloatTensor([min_scale])
+        else:
+            self._min_scale = None
+
+        if max_scale:
+            self._max_scale = torch.cuda.FloatTensor([max_scale])
+        else:
+            self._max_scale = None
+
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
         self._growth_interval = growth_interval
diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py
index c4e950f68..1ee34931f 100644
--- a/colossalai/amp/naive_amp/naive_amp.py
+++ b/colossalai/amp/naive_amp/naive_amp.py
@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
     """
 
     def __init__(self, optim: Optimizer, *args, **kwargs):
-        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+        optim = FP16Optimizer(optim, *args, **kwargs)
         super().__init__(optim)
 
     def backward(self, loss: Tensor):
-        """Backward with gradient scaler
-
-        :param loss: loss computed by a loss function
-        :type loss: torch.Tensor
-        """
-        loss = self.optim.scale_loss(loss)
-        loss.backward()
+        self.optim.backward(loss)
 
     def step(self):
         return self.optim.step()
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index d87f9658b..011859881 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
         if is_using_pp():
             assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
         if amp_mode == AMP_TYPE.NAIVE:
-            cfg_['clip_grad'] = clip_grad_norm
+            cfg_['clip_grad_norm'] = clip_grad_norm
         model, optimizer, criterion = convert_to_amp(model=model,
                                                      optimizer=optimizer,
                                                      criterion=criterion,
diff --git a/colossalai/zero/sharded_optim/sharded_optim.py b/colossalai/zero/sharded_optim/sharded_optim.py
index 9dff355db..2ea2feaf6 100644
--- a/colossalai/zero/sharded_optim/sharded_optim.py
+++ b/colossalai/zero/sharded_optim/sharded_optim.py
@@ -1,4 +1,3 @@
-from itertools import groupby
 from colossalai.utils.cuda import get_current_device
 import torch
 import torch.distributed as dist
@@ -7,7 +6,7 @@ from torch.optim import Optimizer
 from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
                      release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
@@ -16,38 +15,26 @@ from functools import partial
 
 class ShardedOptimizer(ColossalaiOptimizer):
 
-    def __init__(
-            self,
-            optimizer: Optimizer,
-
-            # grad scaler config
-            initial_scale=2**32,
-            min_scale=1,
-            growth_factor=2,
-            backoff_factor=0.5,
-            growth_interval=1000,
-            hysteresis=2,
-            max_scale: int = 2**32,
-
-            # grad clipping
-            clip_grad_norm=2.0,
-            verbose=False,
-
-            # communication
-            reduce_bucket_size=500000000,
-            communication_dtype=torch.float16,
-            overlap_communication=False,
-
-            # stage 2
-            partition_grad=False,
-            
-            dp_parallel_mode=ParallelMode.DATA,
-            mp_parallel_mode=ParallelMode.MODEL,
-            
-            # cpu offload
-            cpu_offload=False,
-            cpu_fp16_param=False,
-            cpu_fp16_grad=False):
+    def __init__(self,
+                 optimizer: Optimizer,
+                 initial_scale=2**32,
+                 min_scale=1,
+                 growth_factor=2,
+                 backoff_factor=0.5,
+                 growth_interval=1000,
+                 hysteresis=2,
+                 max_scale: int = 2**32,
+                 clip_grad_norm=2.0,
+                 verbose=False,
+                 reduce_bucket_size=500000000,
+                 communication_dtype=torch.float16,
+                 overlap_communication=False,
+                 partition_grad=False,
+                 dp_parallel_mode=ParallelMode.DATA,
+                 mp_parallel_mode=ParallelMode.MODEL,
+                 cpu_offload=False,
+                 cpu_fp16_param=False,
+                 cpu_fp16_grad=False):
 
         # TODO: add support for
         # 1. fp16 master weights
@@ -257,12 +244,13 @@ class ShardedOptimizer(ColossalaiOptimizer):
                         reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
                                                  param=param,
                                                  reduce_rank=reduce_rank)
-                        
+
                         # define hook
                         # NOT IMPORTANT BUT GOOD TO KNOW:
                         # args here are not grads, but allow_unreachable and accumulate_grad
                         def reduce_grad_hook(*args):
                             reduction_func()
+
                         accum_grad_obj.register_hook(reduce_grad_hook)
 
                     _define_and_attach(param, reduce_rank)
@@ -293,8 +281,8 @@ class ShardedOptimizer(ColossalaiOptimizer):
     def _reduce_grads_in_bucket(self, reduce_rank=None):
         # reduce grads
         self._reduce_grads_by_rank(reduce_rank=reduce_rank,
-                                    grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-                                    bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
+                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
 
         # use communication stream if overlapping
         # communication with computation
@@ -323,7 +311,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
                 # we do not keep the gradient after reduction
                 if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
                     if self._overlap_communication:
-                        # we need to keep this gradient for now as reduction may 
+                        # we need to keep this gradient for now as reduction may not
                         # be completed yet since it is using a different cuda stream
                         self._param_store.add_previous_reduced_param(param)
                     else:
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []
 
-
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
@@ -501,7 +488,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
     def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         # compute combined scale factor for this group
         combined_scale = self.loss_scale
-        
+
         if self._clip_grad_norm > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
@@ -562,7 +549,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
                 for param in param_group:
                     if param.grad is not None:
                         self._reduce_and_remove_grads_by_bucket(param)
-        
+
         # we need to reduce the gradients
         # left in the communication bucket
         self._reduce_grads_in_bucket()
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index d78ac3ecc..14b670a88 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py
new file mode 100644
index 000000000..c777d1587
--- /dev/null
+++ b/tests/test_amp/test_naive_fp16.py
@@ -0,0 +1,83 @@
+import torch
+import colossalai
+import copy
+import pytest
+import torch.multiprocessing as mp
+from colossalai.amp import convert_to_naive_amp
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.utils import free_port
+from functools import partial
+
+
+def check_equal(a, b):
+    """
+    This function checks if two tensors are equal within tolerance
+    """
+    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+
+
+def run_naive_amp():
+    """
+    In this test, we compare the naive fp16 optimizer implemented in colossalai
+    against the fp32 torch optimizer
+    """
+
+    # create layer
+    test_models = ['repeated_computed_layers', 'nested_model']
+    for test_name in test_models:
+        get_component_func = non_distributed_component_funcs.get_callable(test_name)
+        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+
+        # create model
+        amp_model = model_builder(checkpoint=True).cuda()
+        torch_model = copy.deepcopy(amp_model)
+
+        # create optimizer
+        amp_optimizer = optim_builder(amp_model)
+        torch_optimizer = optim_builder(torch_model)
+
+        # inject naive amp
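+        # initial_scale=1 keeps the scaled fp16 loss numerically identical to
+        # the unscaled fp32 loss, so gradients and updates can be compared directly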
+        amp_config = dict(initial_scale=1)
+        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)
+
+        # create data
+        data_iter = iter(train_dataloader)
+        data, label = next(data_iter)
+        data = data.cuda()
+
+        # forward pass
+        amp_output = amp_model(data)
+        torch_output = torch_model(data)
+        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
+
+        # backward
+        amp_optimizer.backward(amp_output.mean())
+        torch_output.mean().backward()
+
+        # check grad
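+        # the amp model keeps fp16 gradients, so compare them against the torch
+        # model's fp32 gradients cast to half precision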
+        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
+            assert torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
+
+        # step
+        amp_optimizer.step()
+        torch_optimizer.step()
+
+        # check updated param
+        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
+            assert torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    run_naive_amp()
+
+
+@pytest.mark.dist
+def test_naive_amp():
+    world_size = 1
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_naive_amp()