From a109225bc2292b3b03cd7324b14bca282c0115ab Mon Sep 17 00:00:00 2001
From: ver217 <lhx0217@gmail.com>
Date: Thu, 3 Mar 2022 15:06:18 +0800
Subject: [PATCH] add sharded adam

---
 .../zero/sharded_model/sharded_model_v2.py    |  19 +-
 colossalai/zero/sharded_optim/sharded_adam.py | 163 ++++++++++++++++++
 2 files changed, 175 insertions(+), 7 deletions(-)
 create mode 100644 colossalai/zero/sharded_optim/sharded_adam.py

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 36e3e4b30..7dbc99c3d 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -6,12 +6,13 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
+                                       register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.zero.sharded_param import ShardedParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
+from colossalai.zero.sharded_param import ShardedParam
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
@@ -64,10 +65,10 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None,
-        # we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = \
-            gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None, we will set
+        # gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = gradient_predivide_factor if \
+            gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
 
@@ -83,6 +84,10 @@ class ShardedModelV2(nn.Module):
         loss.backward()
         self._final_backward_hook()
 
+    def backward_by_grad(self, tensor, grad):
+        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+        self._final_backward_hook()
+
     @torch.no_grad()
     def _final_backward_hook(self) -> None:
         if self._require_backward_grad_sync:
@@ -110,7 +115,7 @@ class ShardedModelV2(nn.Module):
         """
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
-         a single shard of the summed gradient across all
+        a single shard of the summed gradient across all
         GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
 
             before reduce_scatter:
diff --git a/colossalai/zero/sharded_optim/sharded_adam.py b/colossalai/zero/sharded_optim/sharded_adam.py
new file mode 100644
index 000000000..6980b2b71
--- /dev/null
+++ b/colossalai/zero/sharded_optim/sharded_adam.py
@@ -0,0 +1,163 @@
+from enum import Enum
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.sharded_model import ShardedModelV2
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.optim import Optimizer
+
+from ._utils import has_inf_or_nan
+
+
+class OptimState(Enum):
+    SCALED = 1
+    UNSCALED = 2
+
+
+class ShardedAdam(ColossalaiOptimizer):
+
+    def __init__(self,
+                 adam_optim: Optimizer,
+                 sharded_model: nn.Module,
+                 cpu_offload: bool = False,
+                 initial_scale: float = 2**32,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: float = 1000,
+                 hysteresis: float = 2,
+                 max_scale: int = 2**32,
+                 dp_process_group: Optional[ProcessGroup] = None,
+                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+        super().__init__(adam_optim)
+        self.model: Union[nn.Module, ShardedModelV2] = sharded_model
+        self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
+        self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.optim_state: OptimState = OptimState.UNSCALED
+        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
+        self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
+        # Grad scaler
+        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
+                                             min_scale=min_scale,
+                                             growth_factor=growth_factor,
+                                             backoff_factor=backoff_factor,
+                                             growth_interval=growth_interval,
+                                             hysteresis=hysteresis,
+                                             max_scale=max_scale)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device)
+
+        # Early state initialization
+        for group in adam_optim.param_groups:
+            for p in group['params']:
+                state_shape = p.shape
+                if hasattr(p, 'ca_attr'):
+                    assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
+                    # TODO: use payload shape
+                    state_shape = p.ca_attr.payload(self.state_device)
+                state = adam_optim.state[p]
+                assert len(state) == 0, 'adam optimizer initialized'
+                state['step'] = 0
+                # Exponential moving average of gradient values
+                state['exp_avg'] = torch.zeros(state_shape,
+                                               memory_format=torch.preserve_format,
+                                               dtype=torch.float,
+                                               device=self.state_device)
+                # Exponential moving average of squared gradient values
+                state['exp_avg_sq'] = torch.zeros(state_shape,
+                                                  memory_format=torch.preserve_format,
+                                                  dtype=torch.float,
+                                                  device=self.state_device)
+                if group['amsgrad']:
+                    # Maintains max of all exp. moving avg. of sq. grad. values
+                    state['max_exp_avg_sq'] = torch.zeros(state_shape,
+                                                          memory_format=torch.preserve_format,
+                                                          dtype=torch.float,
+                                                          device=self.state_device)
+
+    def step(self, *args, **kwargs):
+        # unscale grads if scaled
+        if self.optim_state == OptimState.SCALED:
+            self._unscale_grads()
+
+        found_inf = self._check_overflow()
+        self.grad_scaler.update(found_inf)
+
+        if found_inf:
+            self.zero_grad()
+            return
+
+        # Write payload back to p.data
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                data = p.data
+                if hasattr(p, 'ca_attr'):
+                    data = p.ca_attr.payload(self.state_device)
+                if torch.is_floating_point(data) and data.dtype != torch.float:
+                    data = data.to(torch.float)
+                p.data = data
+        ret = self.optim.step(*args, **kwargs)
+        # Set p.data to None
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.data = None
+        return ret
+
+    def backward(self, loss: Tensor) -> None:
+        loss = self.loss_scale * loss
+        self.optim_state = OptimState.SCALED
+        if self.model_is_sharded:
+            self.model.backward(loss)
+        else:
+            super().backward(loss)
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        if self.model_is_sharded:
+            self.model.backward_by_grad(tensor, grad)
+        else:
+            super().backward_by_grad(tensor, grad)
+
+    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+        if self.optim_state == OptimState.SCALED:
+            self._unscale_grads()
+        return super().clip_grad_norm(model, max_norm)
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def _check_overflow(self):
+        # clear previous overflow record
+        self._found_overflow.fill_(0.0)
+
+        # check for overflow
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                if has_inf_or_nan(p.grad):
+                    self._found_overflow.fill_(1.0)
+                    break
+
+        # all-reduce across dp group
+        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.dp_process_group)
+
+        # all-reduce over model parallel group
+        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)
+
+        if self._found_overflow.item() > 0:
+            return True
+        else:
+            return False
+
+    def _unscale_grads(self):
+        assert self.optim_state == OptimState.SCALED
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    p.grad.data.div_(self.loss_scale)
+        self.optim_state = OptimState.UNSCALED