From 85f933b58b74a892130e679f554a6ead76b456e9 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Thu, 14 Jul 2022 16:57:48 +0800
Subject: [PATCH] [Optimizer] Remove useless ColoOptimizer (#1312)

---
 colossalai/nn/optimizer/__init__.py       |  4 +-
 colossalai/nn/optimizer/colo_optimizer.py | 80 -----------------------
 colossalai/tensor/colo_parameter.py       |  1 -
 tests/test_tensor/test_model.py           |  9 +--
 tests/test_utils/test_colo_checkpoint.py  |  4 +-
 5 files changed, 8 insertions(+), 90 deletions(-)
 delete mode 100644 colossalai/nn/optimizer/colo_optimizer.py

diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py
index f9a2bc98f..14cb01c24 100644
--- a/colossalai/nn/optimizer/__init__.py
+++ b/colossalai/nn/optimizer/__init__.py
@@ -7,9 +7,7 @@ from .lamb import Lamb
 from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
-from .colo_optimizer import ColoOptimizer
 
 __all__ = [
-    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam',
-    'CPU_ADAM_CNT', 'ColoOptimizer'
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
 ]
diff --git a/colossalai/nn/optimizer/colo_optimizer.py b/colossalai/nn/optimizer/colo_optimizer.py
deleted file mode 100644
index 72ac91682..000000000
--- a/colossalai/nn/optimizer/colo_optimizer.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import List, Union, Mapping, Dict, Any
-
-import torch.optim as optim
-from torch import Tensor
-from colossalai.tensor.colo_tensor import ColoTensor
-
-
-class ColoOptimizer(optim.Optimizer):
-
-    def __init__(self, named_params: Mapping[str, Union[Tensor, ColoTensor]], optimizer_class, *optimizer_args,
-                 **optimizer_kwargs):
-        """
-        ColoOptimizer collects all tensors in type of ColoTensor and torch.Tensor,
-        then use these tensors as ``params`` for optimizers
-
-        Args:
-            named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
-                of parameters, where key is the parameter key, value is either
-                Tensor or ColoTensor. This usually used in
-                conjunction with model.named_parameters(), the same as PyTorch.
-            optimizer_class (torch.optim.Optimizer): the Optimizer to use
-                locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
-            *optimizer_args: the arguments to initialize the optimizer.
-            **optimizer_kwargs: the key-word arguments to initialize the optimizer.
-
-        """
-        self._optim = optimizer_class([p for n, p in named_params], *optimizer_args, **optimizer_kwargs)
-        self.param_groups = self._optim.param_groups
-        self.state = self._optim.state
-
-    def zero_grad(self, set_to_none: bool = False):    # type: ignore[override]
-        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
-
-        Args:
-            set_to_none (bool): instead of setting to zero, set the grads to None.
-                This will in general have lower memory footprint, and can modestly improve performance.
-                However, it changes certain behaviors. For example:
-                1. When the user tries to access a gradient and perform manual ops on it,
-                a None attribute or a Tensor full of 0s will behave differently.
-                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
-                are guaranteed to be None for params that did not receive a gradient.
-                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
-                (in one case it does the step with a gradient of 0 and in the other it skips
-                the step altogether).
-        """
-        self._optim.zero_grad(set_to_none)
-
-    def step(self, closure=None):
-        r"""Performs a single optimization step (parameter update).
-
-        Args:
-            closure (callable): A closure that reevaluates the model and
-                returns the loss. Optional for most optimizers.
-
-        .. note::
-            Unless otherwise specified, this function should not modify the
-            ``.grad`` field of the parameters.
-        """
-        self._optim.step(closure)
-
-    def state_dict(self) -> Dict[str, Any]:
-        """
-        Returned state and param_groups will contain parameter keys
-        instead of parameter indices like torch.optim.Optimizer.
-        """
-        return self._optim.state_dict()
-
-    def load_state_dict(self, state_dict: Mapping[str, Any]):
-        r"""Loads the ColoOptimizer state.
-
-        Args:
-            state_dict (dict): ColoOptimizer state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        self._optim.load_state_dict(state_dict)
-
-    def add_param_group(self, param_group: Any):
-        r"""Add a new param group
-        """
-        self._optim.add_param_group(param_group)
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 8963d2194..17c326516 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -1,7 +1,6 @@
 import torch
 
 from typing import Optional
-from copy import copy
 
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index ee5edae2c..a442f6ad7 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -12,7 +12,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.nn.optimizer import ColoOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer
 
 from tests.components_to_test.registry import non_distributed_component_funcs
 from _utils import split_param_row_tp1d, split_param_col_tp1d
@@ -33,7 +33,8 @@ def run_1d_hybrid_tp(model_name):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-        optimizer_torch = ColoOptimizer(model_torch.named_parameters(), torch.optim.SGD, lr=0.1)
+
+        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))
 
         # Make two models have the same init params
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
@@ -80,7 +81,7 @@ def run_1d_hybrid_tp(model_name):
     if rank == 0:
         model_torch.train()
 
-    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
+    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
 
     for i, (data, label) in enumerate(train_dataloader):
 
@@ -170,7 +171,7 @@ def test_colo_optimizer():
     with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
         model = model_builder(checkpoint=True)
 
-    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
+    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
     for i, (data, label) in enumerate(train_dataloader):
         colo_optimizer.zero_grad()
         data = data.to(get_current_device())
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index 13f54aefe..edc463b0d 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -18,7 +18,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
-from colossalai.nn.optimizer import ColoOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer
 
 from tests.components_to_test.registry import non_distributed_component_funcs
 
@@ -117,7 +117,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     model_reload = model_reload.cuda()
     model_reload.train()
 
-    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
+    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
 
     for i, (data, label) in enumerate(train_dataloader):