From e17e92c54d106f7cfbc7d480200b44e404049b99 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Thu, 3 Mar 2022 12:42:57 +0800
Subject: [PATCH] Polish sharded parameter (#297)

* init shard param from shape tuple

* add more unit tests for shard param

* add more unit tests for sharded param
---
 colossalai/zero/shard_param/__init__.py       |  3 -
 .../zero/sharded_model/sharded_model_v2.py    | 46 +++++++-------
 colossalai/zero/sharded_param/__init__.py     |  3 +
 .../sharded_param.py}                         | 61 ++++++++++++-------
 .../test_shard_param.py                       | 54 +++++++++++-----
 5 files changed, 106 insertions(+), 61 deletions(-)
 delete mode 100644 colossalai/zero/shard_param/__init__.py
 create mode 100644 colossalai/zero/sharded_param/__init__.py
 rename colossalai/zero/{shard_param/shard_param.py => sharded_param/sharded_param.py} (51%)

diff --git a/colossalai/zero/shard_param/__init__.py b/colossalai/zero/shard_param/__init__.py
deleted file mode 100644
index bd7f5e46b..000000000
--- a/colossalai/zero/shard_param/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .shard_param import ShardParam
-
-__all__ = ['ShardParam']
\ No newline at end of file
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index a32afdff2..36e3e4b30 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,4 +1,3 @@
-
 import functools
 from typing import Any, Optional
 
@@ -7,11 +6,10 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                       register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
 from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
 
 
 class ShardedModelV2(nn.Module):
-    def __init__(self,
-                 module: nn.Module,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 reshard_after_forward: bool = True,
-                 mixed_precision: bool = False,
-                 fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
-                 gradient_predivide_factor: Optional[float] = 1.0,
-                 ):
+
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_bucket_size_mb: int = 25,
+        reshard_after_forward: bool = True,
+        mixed_precision: bool = False,
+        fp32_reduce_scatter: bool = False,
+        offload_config: Optional[dict] = None,
+        gradient_predivide_factor: Optional[float] = 1.0,
+    ):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
 
         # Shard the parameters at first
         for _, param in self.module.named_parameters():
-            param.ca_attr = ShardParam(param)
+            param.ca_attr = ShardedParam(param)
             param.ca_attr.shard()
             param._sharded_grad = ShardedGradient(param, self, offload_config)
 
@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None,
+        # we will set gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = \
+            gradient_predivide_factor if gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
 
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
         """
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
-        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+        full gradient for the local batch. The reduce-scatter op will save
+        a single shard of the summed gradient across all
         GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
 
             before reduce_scatter:
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
             orig_grad_data = new_grad.data
             if self.world_size > 1:
                 grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
-                self.reducer.reduce_scatter_async(
-                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+                self.reducer.reduce_scatter_async(grad_chunks,
+                                                  group=self.reduce_scatter_process_group,
+                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
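
A note on the factors above: the local gradient is divided by gradient_predivide_factor before the
reduce-scatter, and the resulting shard is divided by gradient_postdivide_factor = world_size /
gradient_predivide_factor afterwards, so the two divisions always combine to an average over the
data-parallel ranks. Below is a minimal single-process sketch of that flow; averaged_shards is purely
illustrative, the per-rank gradients are simulated with a list, and the bucketed reduce_scatter_async
path is replaced by a plain sum:

    import torch

    def averaged_shards(per_rank_grads, gradient_predivide_factor):
        # Simulate pre-divide -> reduce-scatter(sum) -> post-divide on one process;
        # each entry of per_rank_grads plays the role of one rank's full gradient.
        world_size = len(per_rank_grads)
        postdivide = world_size / gradient_predivide_factor
        # Pre-divide before the collective (helps avoid overflow in low precision).
        predivided = [g / gradient_predivide_factor for g in per_rank_grads]
        # Reduce-scatter: rank r keeps the sum of everyone's r-th chunk.
        chunks = [g.chunk(world_size) for g in predivided]
        shards = [sum(c[r] for c in chunks) for r in range(world_size)]
        # Post-divide so the net effect is an average over ranks.
        return [s / postdivide for s in shards]

    grads = [torch.full((4,), float(r + 1)) for r in range(2)]    # simulated ranks 0 and 1
    print(averaged_shards(grads, gradient_predivide_factor=1.0))  # every shard element == 1.5
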
diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py
new file mode 100644
index 000000000..527cf11d6
--- /dev/null
+++ b/colossalai/zero/sharded_param/__init__.py
@@ -0,0 +1,3 @@
+from .sharded_param import ShardedParam
+
+__all__ = ['ShardedParam']
diff --git a/colossalai/zero/shard_param/shard_param.py b/colossalai/zero/sharded_param/sharded_param.py
similarity index 51%
rename from colossalai/zero/shard_param/shard_param.py
rename to colossalai/zero/sharded_param/sharded_param.py
index 7bc36470f..f7363d0a5 100644
--- a/colossalai/zero/shard_param/shard_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -1,41 +1,59 @@
-from enum import Enum
-
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from typing import Union, Tuple, Optional
+import numpy
 
 
-class TensorType(Enum):
-    GRAD = 1
-    DATA = 2
-
-
-class ShardParam(object):
+class ShardedParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
-    on different processes.
+    across the memory space of different processes.
     """
 
-    def __init__(
-        self,
-        param: torch.nn.Parameter,
-        tensor_type: TensorType = TensorType.DATA,
-        process_group=None,
-    ) -> None:
+    def __init__(self,
+                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 is_sharded: bool = False,
+                 device: Optional[torch.device] = None) -> None:
+        r"""
+        other: either an existing torch.nn.Parameter, or a tuple used as the shape to allocate a new param.
+        process_group: the process group that stores the sharded data.
+        is_sharded: whether to shard the param during __init__.
+        device: the device on which to place the param data payload.
+        """
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
-        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
-        self._payload_shape = None
-        self._payload_numel = None
-        self._origin_shape = param.shape
-        self._origin_numel = param.numel()
-        self._origin_dtype = param.dtype
         self.is_sharded = False
 
+        # Hijack the data payload of param
+        if isinstance(other, torch.nn.Parameter):
+            self._param_payload = other.data.to(device)
+            self._origin_shape = other.shape
+            self._origin_numel = other.numel()
+            if is_sharded:
+                self.shard()
+        elif isinstance(other, tuple):
+            self._origin_shape = other
+            self._origin_numel = numpy.prod(other)
+
+            # TODO(jiaruifang) can be optimized. Directly allocate the payload with the sharded shape.
+            assert device is not None, "You have to assign a device to initialize a ShardedParam from a shape tuple"
+            self._param_payload = torch.empty(self._origin_shape, device=device)
+            if is_sharded:
+                self.shard()
+        else:
+            raise RuntimeError(f"Initialize ShardedParam failed. `other` has an unsupported type {type(other)}")
+
+        self._payload_numel = None
+
     def payload(self, target_device: torch.device):
+        r"""
+        Get the payload and move it to the target device.
+        """
         return self._param_payload.to(target_device)
 
     def shard(self):
@@ -50,6 +68,7 @@ class ShardParam(object):
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
+        The payload has to be moved to CUDA memory before communication.
         """
         if not self.is_sharded:
             return
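
The two construction paths above (wrapping an existing torch.nn.Parameter, or allocating a fresh
payload from a shape tuple) mirror what the tests below exercise. A minimal usage sketch, assuming
a data-parallel group of world_size 2 has already been launched via colossalai.launch:

    import torch
    from colossalai.zero.sharded_param import ShardedParam

    # Path 1: wrap an existing parameter and shard it immediately.
    param = torch.nn.Parameter(torch.rand(2, 3))
    sparam = ShardedParam(param, process_group=None, is_sharded=True)
    shard = sparam.payload(torch.device('cuda'))    # 3 of the 6 elements stay on each rank

    # Path 2: allocate the payload directly from a shape tuple; a device is mandatory here.
    sparam = ShardedParam((2, 3), process_group=None, is_sharded=False, device=torch.device('cpu'))
    full = sparam.payload(torch.device('cuda'))     # still the full (2, 3) tensor
    sparam.shard()                                  # each rank now keeps only its shard
    sparam.gather()                                 # payload has to be on CUDA memory for the collective
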
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 9973ee524..642cd7f2b 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -1,50 +1,72 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from asyncio.log import logger
 from functools import partial
 
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 from tests.test_zero_data_parallel.common import Net, CONFIG
 
+
+def run_init_shard_param(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    sparam = ShardedParam(param, None, True)
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+    del sparam
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [2, 3])
+
+
 def run_shard_param_check(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-    
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
     logger = get_dist_logger()
     model = Net()
 
     # add an attribute as ca_attr to hijack the access to param.data
     for _, param in model.named_parameters():
         numel_ref = (param.numel() + world_size - 1) // world_size
-        param.ca_attr = ShardParam(param)
+        param.ca_attr = ShardedParam(param)
         param.ca_attr.shard()
         param_data = param.ca_attr.payload(torch.device('cpu'))
-        logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
-        assert(numel_ref == param_data.numel())
+        assert (numel_ref == param_data.numel())
 
     for _, param in model.named_parameters():
         param.ca_attr.gather()
         param_data = param.ca_attr.payload(torch.device('cpu'))
-        logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])
-    
+
     disable_existing_loggers([logger])
 
+
 @pytest.mark.dist
-def test_run_shard_shape():
+def test_shard_shape():
     world_size = 2
     run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
+
+@pytest.mark.dist
+def test_init_shard_param():
+    world_size = 2
+    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_run_shard_shape()
\ No newline at end of file
+    test_shard_shape()
+    test_init_shard_param()
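
The shard-size check in run_shard_param_check (numel_ref = (param.numel() + world_size - 1) //
world_size) implies chunk-based partitioning with padding up to a multiple of world_size. A minimal
single-process sketch of that idea follows; flatten_and_shard is an illustrative stand-in, not the
actual get_shard helper imported from colossalai.zero.sharded_model._zero3_utils:

    import torch

    def flatten_and_shard(tensor, rank, world_size):
        # Flatten, zero-pad to a multiple of world_size, and return this rank's
        # chunk together with the number of padded elements.
        flat = tensor.flatten()
        shard_numel = (flat.numel() + world_size - 1) // world_size   # == numel_ref in the test
        pad = shard_numel * world_size - flat.numel()
        if pad > 0:
            flat = torch.cat([flat, flat.new_zeros(pad)])
        return flat[rank * shard_numel:(rank + 1) * shard_numel].clone(), pad

    t = torch.rand(2, 3)
    shard, pad = flatten_and_shard(t, rank=0, world_size=2)
    assert shard.numel() == (t.numel() + 1) // 2 and pad == 0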