From 0ce8924cebd419d4638efa22f9327774104cfee4 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Thu, 21 Apr 2022 14:15:48 +0800
Subject: [PATCH] [tensor] reorganize files (#820)

---
 colossalai/gemini/tensor/_ops/__init__.py     |  3 --
 colossalai/gemini/tensor/api.py               | 17 ---------
 colossalai/tensor/__init__.py                 |  7 ++++
 colossalai/tensor/_ops/__init__.py            |  3 ++
 .../{gemini => }/tensor/_ops/element_wise.py  | 14 ++++----
 colossalai/{gemini => }/tensor/_ops/init.py   |  6 ++--
 colossalai/{gemini => }/tensor/_ops/linear.py | 12 +++----
 .../colo_tensor.py}                           | 21 ++++++-----
 .../__init__.py => tensor/op_wrapper.py}      | 35 +++++++++++++------
 colossalai/{gemini => }/tensor/utils.py       | 14 +++-----
 .../test_tensor.py => test_tensor/test_op.py} | 15 +++-----
 11 files changed, 71 insertions(+), 76 deletions(-)
 delete mode 100644 colossalai/gemini/tensor/_ops/__init__.py
 delete mode 100644 colossalai/gemini/tensor/api.py
 create mode 100644 colossalai/tensor/__init__.py
 create mode 100644 colossalai/tensor/_ops/__init__.py
 rename colossalai/{gemini => }/tensor/_ops/element_wise.py (64%)
 rename colossalai/{gemini => }/tensor/_ops/init.py (83%)
 rename colossalai/{gemini => }/tensor/_ops/linear.py (70%)
 rename colossalai/{gemini/tensor/stateful_tensor.py => tensor/colo_tensor.py} (51%)
 rename colossalai/{gemini/tensor/__init__.py => tensor/op_wrapper.py} (52%)
 rename colossalai/{gemini => }/tensor/utils.py (64%)
 rename tests/{test_gemini/test_tensor.py => test_tensor/test_op.py} (74%)

diff --git a/colossalai/gemini/tensor/_ops/__init__.py b/colossalai/gemini/tensor/_ops/__init__.py
deleted file mode 100644
index 199f456ee..000000000
--- a/colossalai/gemini/tensor/_ops/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .init import stateful_uniform
-from .linear import stateful_linear
-from .element_wise import stateful_mean
\ No newline at end of file
diff --git a/colossalai/gemini/tensor/api.py b/colossalai/gemini/tensor/api.py
deleted file mode 100644
index 92a7e98fb..000000000
--- a/colossalai/gemini/tensor/api.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import (
-    Callable,
-    Dict,
-)
-
-# Custom sharded ops
-_STATEFUL_OPS: Dict[str, Callable] = {}
-
-
-def _register_stateful_op(op, func):
-    from inspect import signature
-    if len(signature(func).parameters) != 4:
-        raise TypeError(f'Custom stateful op function expects signature: '
-                        f'(types, args, kwargs, process_group), but received '
-                        f'signature: {signature(func)}')
-    global _STATEFUL_OPS
-    _STATEFUL_OPS[op] = func
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
new file mode 100644
index 000000000..157da5db6
--- /dev/null
+++ b/colossalai/tensor/__init__.py
@@ -0,0 +1,7 @@
+from .op_wrapper import (
+    colo_op_impl,)
+from .colo_tensor import ColoTensor
+from .utils import convert_parameter
+from ._ops import *
+
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py
new file mode 100644
index 000000000..0fb96d9fa
--- /dev/null
+++ b/colossalai/tensor/_ops/__init__.py
@@ -0,0 +1,3 @@
+from .init import colo_uniform
+from .linear import colo_linear
+from .element_wise import colo_mean
\ No newline at end of file
diff --git a/colossalai/gemini/tensor/_ops/element_wise.py b/colossalai/tensor/_ops/element_wise.py
similarity index 64%
rename from colossalai/gemini/tensor/_ops/element_wise.py
rename to colossalai/tensor/_ops/element_wise.py
index 773ce4799..1843784e6 100644
--- a/colossalai/gemini/tensor/_ops/element_wise.py
+++ b/colossalai/tensor/_ops/element_wise.py
@@ -1,17 +1,17 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
 
 
-@stateful_op_impl(torch.mean)
-def stateful_mean(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.mean)
+def colo_mean(types, args=(), kwargs=None, pg=None):
     stateful_tensor = args[0]
     return torch.mean(stateful_tensor.torch_tensor())
 
 
 def register_elementwise_op(op):
 
-    @stateful_op_impl(op)
+    @colo_op_impl(op)
     def elementwise_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for the elementwise op such
@@ -20,8 +20,8 @@ def register_elementwise_op(op):
         """
         input_tensor = args[0]
         # Validate types
-        if not isinstance(input_tensor, StatefulTensorV2):
-            raise TypeError("input needs to be a StatefulTensorV2")
+        if not isinstance(input_tensor, ColoTensor):
+            raise TypeError("input needs to be a ColoTensor")
         return op(input_tensor.torch_tensor())
 
 
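The handlers above keep the ShardedTensor-style (types, args, kwargs, pg) signature and simply unwrap the ColoTensor before calling the plain torch op; gelu and relu are presumably registered through register_elementwise_op in the unchanged tail of this file, since the relocated test at the end of this patch relies on them. A minimal usage sketch, assuming the patched package is importable:

    import torch
    from colossalai.tensor import ColoTensor

    t_ref = torch.randn(3, 5)
    t = ColoTensor(t_ref.clone())

    # torch.mean is routed to colo_mean via ColoTensor.__torch_function__,
    # which unwraps the ColoTensor and returns a plain torch.Tensor.
    assert torch.allclose(torch.mean(t), torch.mean(t_ref))
    assert torch.allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
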
diff --git a/colossalai/gemini/tensor/_ops/init.py b/colossalai/tensor/_ops/init.py
similarity index 83%
rename from colossalai/gemini/tensor/_ops/init.py
rename to colossalai/tensor/_ops/init.py
index 079ffe7c3..7d4b2cceb 100644
--- a/colossalai/gemini/tensor/_ops/init.py
+++ b/colossalai/tensor/_ops/init.py
@@ -1,5 +1,5 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
+from colossalai.tensor.op_wrapper import colo_op_impl
 
 
 def validate_param(param, param_name):
@@ -7,8 +7,8 @@ def validate_param(param, param_name):
         raise ValueError(f"param: {param_name} shouldn't be None!")
 
 
-@stateful_op_impl(torch.nn.init.uniform_)
-def stateful_uniform(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.nn.init.uniform_)
+def colo_uniform(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
     distribution :math:`\mathcal{U}(a, b)`.
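
Only the import and the decorator target change in this hunk; the body of colo_uniform lies outside it. Judging from the element-wise handlers, a plausible shape for an in-place init handler in this style is sketched below; this is an assumption for illustration, not the actual body of colo_uniform:

    # Hedged sketch of an in-place init handler in this style (assumption,
    # not the actual body of colo_uniform, which is outside this hunk):
    def sketch_colo_uniform(types, args=(), kwargs=None, pg=None):
        colo_tensor = args[0]
        validate_param(colo_tensor, "tensor")
        kwargs = kwargs or {}
        a, b = kwargs.get("a", 0.0), kwargs.get("b", 1.0)
        # Fill the wrapped torch.Tensor in place with samples from U(a, b).
        torch.nn.init.uniform_(colo_tensor.torch_tensor(), a=a, b=b)
        return colo_tensor
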
diff --git a/colossalai/gemini/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
similarity index 70%
rename from colossalai/gemini/tensor/_ops/linear.py
rename to colossalai/tensor/_ops/linear.py
index 7998e353d..e75f18609 100644
--- a/colossalai/gemini/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -1,11 +1,11 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from ..stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.colo_tensor import ColoTensor
 from packaging import version
 
 
-@stateful_op_impl(torch.nn.functional.linear)
-def stateful_linear(types, args, kwargs, pg):
+@colo_op_impl(torch.nn.functional.linear)
+def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
@@ -19,11 +19,11 @@ def stateful_linear(types, args, kwargs, pg):
             bias = None
     else:
         bias = kwargs.get('bias', None)
-        if isinstance(bias, StatefulTensorV2):
+        if isinstance(bias, ColoTensor):
             bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
-    if isinstance(weight, StatefulTensorV2):
+    if isinstance(weight, ColoTensor):
         return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
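
After the rename, colo_linear unwraps a ColoTensor weight (and a keyword bias) and falls through to the stock torch.nn.functional.linear; the comment about communication logic is still a placeholder. A short usage sketch, assuming the patch is applied, with shapes following F.linear's (out_features, in_features) weight convention:

    import torch
    from colossalai.tensor import ColoTensor

    inp = torch.rand(10, 32)
    weight = ColoTensor(torch.rand(16, 32))
    bias = ColoTensor(torch.rand(16))

    # Dispatched to colo_linear through ColoTensor.__torch_function__;
    # weight and bias are unwrapped before calling the plain F.linear.
    out = torch.nn.functional.linear(inp, weight, bias=bias)
    assert out.shape == (10, 16)
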
diff --git a/colossalai/gemini/tensor/stateful_tensor.py b/colossalai/tensor/colo_tensor.py
similarity index 51%
rename from colossalai/gemini/tensor/stateful_tensor.py
rename to colossalai/tensor/colo_tensor.py
index dbfd088b2..47e693720 100644
--- a/colossalai/gemini/tensor/stateful_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,11 +1,11 @@
 import torch
-from .api import _STATEFUL_OPS
+from .op_wrapper import _COLOSSAL_OPS
 
 
-class StatefulTensorV2(object):
+class ColoTensor(object):
 
     def __new__(cls, *args, **kwargs):
-        return super(StatefulTensorV2, cls).__new__(cls)
+        return super(ColoTensor, cls).__new__(cls)
 
     def __init__(self, t: torch.Tensor) -> None:
         self._torch_tensor = t
@@ -15,16 +15,15 @@ class StatefulTensorV2(object):
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
-        global _STATEFUL_OPS
-        if func in _STATEFUL_OPS:
-            # Find StatefulTensorV2 instance to get process_group.
+        global _COLOSSAL_OPS
+        if func in _COLOSSAL_OPS:
             for arg in args:
-                if isinstance(arg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(arg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
 
             for kwarg in kwargs.values():
-                if isinstance(kwarg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(kwarg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
 
         raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
-                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
+                           f"kwargs: {kwargs} not supported for ColoTensor!")
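
The dispatch logic itself is unchanged by the rename: any torch function call whose args or kwargs contain a ColoTensor is routed to the handler registered in _COLOSSAL_OPS, and functions without a handler raise RuntimeError rather than silently falling back. Since ColoTensor subclasses object rather than torch.Tensor, plain tensor methods are also unavailable on it. A small sketch of what that means in practice, assuming the patch is applied:

    import torch
    from colossalai.tensor import ColoTensor

    t = ColoTensor(torch.randn(3, 5))

    # Registered op: routed through _COLOSSAL_OPS[torch.mean] by __torch_function__.
    m = torch.mean(t)

    # Ops without a registered handler raise RuntimeError (see __torch_function__ above),
    # so ordinary torch code has to unwrap the ColoTensor explicitly.
    s = torch.sum(t.torch_tensor())
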
diff --git a/colossalai/gemini/tensor/__init__.py b/colossalai/tensor/op_wrapper.py
similarity index 52%
rename from colossalai/gemini/tensor/__init__.py
rename to colossalai/tensor/op_wrapper.py
index fcf909ba4..577c85353 100644
--- a/colossalai/gemini/tensor/__init__.py
+++ b/colossalai/tensor/op_wrapper.py
@@ -1,24 +1,39 @@
+from typing import (
+    Callable,
+    Dict,
+)
 import functools
-from .api import (
-    _register_stateful_op,)
+
+# Custom colossal ops
+_COLOSSAL_OPS: Dict[str, Callable] = {}
 
 
-def stateful_op_impl(func):
+def _register_colo_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 4:
+        raise TypeError(f'Custom colo op function expects signature: '
+                        f'(types, args, kwargs, process_group), but received '
+                        f'signature: {signature(func)}')
+    global _COLOSSAL_OPS
+    _COLOSSAL_OPS[op] = func
+
+
+def colo_op_impl(func):
     """
     Provides a way for users to write their own custom operator. This
-    can be used to override existing StatefulTensorV2 operators or write a new
-    one not supported by StatefulTensorV2. If the operator in question is covered
-    by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
+    can be used to override existing ColoTensor operators or write a new
+    one not supported by ColoTensor. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
     parameters, the function provided will be invoked for that operator.
 
     Example::
-        >>> @stateful_op_impl(torch.nn.functional.linear)
+        >>> @colo_op_impl(torch.nn.functional.linear)
         >>> def my_custom_linear(types, args, kwargs, process_group):
         >>>   ....
         >>>
         >>> input = torch.rand(10, 32)
-        >>> weight = StatefulTensorV2(torch.rand(32, 16))
-        >>> bias = StatefulTensorV2(torch.rand(16))
+        >>> weight = ColoTensor(torch.rand(32, 16))
+        >>> bias = ColoTensor(torch.rand(16))
         >>> # This will call `my_custom_linear` instead of the default.
         >>> torch.nn.functional.linear(input, weight, bias)
 
@@ -32,7 +47,7 @@ def stateful_op_impl(func):
     """
 
     def decorator_sharded_func(wrapped_func):
-        _register_stateful_op(func, wrapped_func)
+        _register_colo_op(func, wrapped_func)
 
         @functools.wraps(wrapped_func)
         def wrapper(*args, **kwargs):
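
The decorator keeps the four-parameter (types, args, kwargs, process_group) contract that _register_colo_op enforces. A hedged sketch of registering a custom handler with the renamed decorator; the handler name and the choice of torch.abs are illustrative, not part of this patch:

    import torch
    from colossalai.tensor import ColoTensor, colo_op_impl

    @colo_op_impl(torch.abs)
    def my_colo_abs(types, args=(), kwargs=None, pg=None):
        # Unwrap the ColoTensor and delegate to the stock op.
        input_tensor = args[0]
        if not isinstance(input_tensor, ColoTensor):
            raise TypeError("input needs to be a ColoTensor")
        return torch.abs(input_tensor.torch_tensor())

    t = ColoTensor(torch.randn(4))
    out = torch.abs(t)    # now routed to my_colo_abs via __torch_function__
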
diff --git a/colossalai/gemini/tensor/utils.py b/colossalai/tensor/utils.py
similarity index 64%
rename from colossalai/gemini/tensor/utils.py
rename to colossalai/tensor/utils.py
index 869d1ad1c..1430e5191 100644
--- a/colossalai/gemini/tensor/utils.py
+++ b/colossalai/tensor/utils.py
@@ -1,14 +1,10 @@
 import torch
-import torch.distributed as dist
-from torch.distributed import distributed_c10d
 
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
-def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
-    if not tensor.is_contiguous():
-        raise ValueError('input tensor is not a contiguous Tensor')
-    return StatefulTensorV2(tensor)
+def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
+    return ColoTensor(tensor)
 
 
 def convert_parameter(module: torch.nn.Module, param_name: str):
@@ -26,10 +22,10 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
 
     st = _convert_tensor(tensor)
 
-    # Replace param with StatefulTensorV2.
+    # Replace param with ColoTensor.
 
     # Need to delete the attribute first since param_name might be
-    # torch.nn.Parameter and can't be replaced with StatefulTensorV2 which is
+    # torch.nn.Parameter and can't be replaced with ColoTensor which is
     # not torch.nn.Parameter.
     delattr(module, param_name)
 
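convert_parameter still deletes the original torch.nn.Parameter and re-binds param_name to the ColoTensor (the setattr presumably sits in the unchanged tail of the function); the contiguity check was dropped together with the now-unused distributed imports. A brief usage sketch, assuming the patch is applied, mirroring what the relocated test does by hand:

    import torch
    from colossalai.tensor import convert_parameter

    fc = torch.nn.Linear(32, 16)
    convert_parameter(fc, 'weight')
    convert_parameter(fc, 'bias')

    # weight and bias are now ColoTensor attributes rather than nn.Parameters,
    # so the forward pass dispatches F.linear through colo_linear.
    out = fc(torch.rand(10, 32))
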
diff --git a/tests/test_gemini/test_tensor.py b/tests/test_tensor/test_op.py
similarity index 74%
rename from tests/test_gemini/test_tensor.py
rename to tests/test_tensor/test_op.py
index f403df5b4..4c9e72a92 100644
--- a/tests/test_gemini/test_tensor.py
+++ b/tests/test_tensor/test_op.py
@@ -1,10 +1,6 @@
 from numpy import allclose
 import torch
-from torch import nn
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
-# TODO(jiaruifang) auto import
-from colossalai.gemini.tensor._ops import *
-from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
 
@@ -18,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
 
-    sharded_weight = StatefulTensorV2(fc_ref.weight)
-    sharded_bias = StatefulTensorV2(fc_ref.bias)
+    sharded_weight = ColoTensor(fc_ref.weight)
+    sharded_bias = ColoTensor(fc_ref.bias)
 
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -45,15 +41,14 @@ def test_linear():
 
 # The test case failed
 # def test_uniform():
-#     t = StatefulTensorV2(torch.zeros(3, 5))
-#     # print(_STATEFUL_OPS)
+#     t = ColoTensor(torch.zeros(3, 5))
 #     torch.nn.init.uniform_(t)
 #     print(t)
 
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = StatefulTensorV2(t_ref.clone())
+    t = ColoTensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))