From 0ce8924cebd419d4638efa22f9327774104cfee4 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Thu, 21 Apr 2022 14:15:48 +0800
Subject: [PATCH] [tensor] reorganize files (#820)

---
 colossalai/gemini/tensor/_ops/__init__.py     |  3 --
 colossalai/gemini/tensor/api.py               | 17 ---------
 colossalai/tensor/__init__.py                 |  7 ++++
 colossalai/tensor/_ops/__init__.py            |  3 ++
 .../{gemini => }/tensor/_ops/element_wise.py  | 14 ++++----
 colossalai/{gemini => }/tensor/_ops/init.py   |  6 ++--
 colossalai/{gemini => }/tensor/_ops/linear.py | 12 +++----
 .../colo_tensor.py}                           | 21 ++++++-----
 .../__init__.py => tensor/op_wrapper.py}      | 35 +++++++++++++------
 colossalai/{gemini => }/tensor/utils.py       | 14 +++-----
 .../test_tensor.py => test_tensor/test_op.py} | 15 +++-----
 11 files changed, 71 insertions(+), 76 deletions(-)
 delete mode 100644 colossalai/gemini/tensor/_ops/__init__.py
 delete mode 100644 colossalai/gemini/tensor/api.py
 create mode 100644 colossalai/tensor/__init__.py
 create mode 100644 colossalai/tensor/_ops/__init__.py
 rename colossalai/{gemini => }/tensor/_ops/element_wise.py (64%)
 rename colossalai/{gemini => }/tensor/_ops/init.py (83%)
 rename colossalai/{gemini => }/tensor/_ops/linear.py (70%)
 rename colossalai/{gemini/tensor/stateful_tensor.py => tensor/colo_tensor.py} (51%)
 rename colossalai/{gemini/tensor/__init__.py => tensor/op_wrapper.py} (52%)
 rename colossalai/{gemini => }/tensor/utils.py (64%)
 rename tests/{test_gemini/test_tensor.py => test_tensor/test_op.py} (74%)

diff --git a/colossalai/gemini/tensor/_ops/__init__.py b/colossalai/gemini/tensor/_ops/__init__.py
deleted file mode 100644
index 199f456ee..000000000
--- a/colossalai/gemini/tensor/_ops/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .init import stateful_uniform
-from .linear import stateful_linear
-from .element_wise import stateful_mean
\ No newline at end of file
diff --git a/colossalai/gemini/tensor/api.py b/colossalai/gemini/tensor/api.py
deleted file mode 100644
index 92a7e98fb..000000000
--- a/colossalai/gemini/tensor/api.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import (
-    Callable,
-    Dict,
-)
-
-# Custom sharded ops
-_STATEFUL_OPS: Dict[str, Callable] = {}
-
-
-def _register_stateful_op(op, func):
-    from inspect import signature
-    if len(signature(func).parameters) != 4:
-        raise TypeError(f'Custom stateful op function expects signature: '
-                        f'(types, args, kwargs, process_group), but received '
-                        f'signature: {signature(func)}')
-    global _STATEFUL_OPS
-    _STATEFUL_OPS[op] = func
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
new file mode 100644
index 000000000..157da5db6
--- /dev/null
+++ b/colossalai/tensor/__init__.py
@@ -0,0 +1,7 @@
+from .op_wrapper import (
+    colo_op_impl,)
+from .colo_tensor import ColoTensor
+from .utils import convert_parameter
+from ._ops import *
+
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py
new file mode 100644
index 000000000..0fb96d9fa
--- /dev/null
+++ b/colossalai/tensor/_ops/__init__.py
@@ -0,0 +1,3 @@
+from .init import colo_uniform
+from .linear import colo_linear
+from .element_wise import colo_mean
\ No newline at end of file
diff --git a/colossalai/gemini/tensor/_ops/element_wise.py b/colossalai/tensor/_ops/element_wise.py
similarity index 64%
rename from colossalai/gemini/tensor/_ops/element_wise.py
rename to colossalai/tensor/_ops/element_wise.py
index 773ce4799..1843784e6 100644
--- a/colossalai/gemini/tensor/_ops/element_wise.py
+++ b/colossalai/tensor/_ops/element_wise.py
@@ -1,17 +1,17 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
 
 
-@stateful_op_impl(torch.mean)
-def stateful_mean(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.mean)
+def colo_mean(types, args=(), kwargs=None, pg=None):
     stateful_tensor = args[0]
     return torch.mean(stateful_tensor.torch_tensor())
 
 
 def register_elementwise_op(op):
 
-    @stateful_op_impl(op)
+    @colo_op_impl(op)
     def elementwise_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for the elementwise op such
@@ -20,8 +20,8 @@ def register_elementwise_op(op):
         """
         input_tensor = args[0]
         # Validate types
-        if not isinstance(input_tensor, StatefulTensorV2):
-            raise TypeError("input needs to be a StatefulTensorV2")
+        if not isinstance(input_tensor, ColoTensor):
+            raise TypeError("input needs to be a ColoTensor")
         return op(input_tensor.torch_tensor())
 
 
diff --git a/colossalai/gemini/tensor/_ops/init.py b/colossalai/tensor/_ops/init.py
similarity index 83%
rename from colossalai/gemini/tensor/_ops/init.py
rename to colossalai/tensor/_ops/init.py
index 079ffe7c3..7d4b2cceb 100644
--- a/colossalai/gemini/tensor/_ops/init.py
+++ b/colossalai/tensor/_ops/init.py
@@ -1,5 +1,5 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
+from colossalai.tensor.op_wrapper import colo_op_impl
 
 
 def validate_param(param, param_name):
@@ -7,8 +7,8 @@ def validate_param(param, param_name):
         raise ValueError(f"param: {param_name} shouldn't be None!")
 
 
-@stateful_op_impl(torch.nn.init.uniform_)
-def stateful_uniform(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.nn.init.uniform_)
+def colo_uniform(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn
     from the uniform distribution :math:`\mathcal{U}(a, b)`.
diff --git a/colossalai/gemini/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
similarity index 70%
rename from colossalai/gemini/tensor/_ops/linear.py
rename to colossalai/tensor/_ops/linear.py
index 7998e353d..e75f18609 100644
--- a/colossalai/gemini/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -1,11 +1,11 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from ..stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.colo_tensor import ColoTensor
 from packaging import version
 
 
-@stateful_op_impl(torch.nn.functional.linear)
-def stateful_linear(types, args, kwargs, pg):
+@colo_op_impl(torch.nn.functional.linear)
+def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
@@ -19,11 +19,11 @@ def stateful_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
-        if isinstance(bias, StatefulTensorV2):
+        if isinstance(bias, ColoTensor):
             bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
-    if isinstance(weight, StatefulTensorV2):
+    if isinstance(weight, ColoTensor):
         return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
diff --git a/colossalai/gemini/tensor/stateful_tensor.py b/colossalai/tensor/colo_tensor.py
similarity index 51%
rename from colossalai/gemini/tensor/stateful_tensor.py
rename to colossalai/tensor/colo_tensor.py
index dbfd088b2..47e693720 100644
--- a/colossalai/gemini/tensor/stateful_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,11 +1,11 @@
 import torch
-from .api import _STATEFUL_OPS
+from .op_wrapper import _COLOSSAL_OPS
 
 
-class StatefulTensorV2(object):
+class ColoTensor(object):
 
     def __new__(cls, *args, **kwargs):
-        return super(StatefulTensorV2, cls).__new__(cls)
+        return super(ColoTensor, cls).__new__(cls)
 
     def __init__(self, t: torch.Tensor) -> None:
         self._torch_tensor = t
@@ -15,16 +15,15 @@ class ColoTensor(object):
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
-        global _STATEFUL_OPS
-        if func in _STATEFUL_OPS:
-            # Find StatefulTensorV2 instance to get process_group.
+        global _COLOSSAL_OPS
+        if func in _COLOSSAL_OPS:
             for arg in args:
-                if isinstance(arg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(arg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
 
             for kwarg in kwargs.values():
-                if isinstance(kwarg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(kwarg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
 
         raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
-                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
+                           f"kwargs: {kwargs} not supported for ColoTensor!")
diff --git a/colossalai/gemini/tensor/__init__.py b/colossalai/tensor/op_wrapper.py
similarity index 52%
rename from colossalai/gemini/tensor/__init__.py
rename to colossalai/tensor/op_wrapper.py
index fcf909ba4..577c85353 100644
--- a/colossalai/gemini/tensor/__init__.py
+++ b/colossalai/tensor/op_wrapper.py
@@ -1,24 +1,39 @@
+from typing import (
+    Callable,
+    Dict,
+)
 import functools
-from .api import (
-    _register_stateful_op,)
+
+# Custom sharded ops
+_COLOSSAL_OPS: Dict[str, Callable] = {}
 
 
-def stateful_op_impl(func):
+def _register_colo_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 4:
+        raise TypeError(f'Custom stateful op function expects signature: '
+                        f'(types, args, kwargs, process_group), but received '
+                        f'signature: {signature(func)}')
+    global _COLOSSAL_OPS
+    _COLOSSAL_OPS[op] = func
+
+
+def colo_op_impl(func):
     """
     Provides a way for users to write their own custom operator. This
-    can be used to override existing StatefulTensorV2 operators or write a new
-    one not supported by StatefulTensorV2. If the operator in question is covered
-    by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
+    can be used to override existing ColoTensor operators or write a new
+    one not supported by ColoTensor. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
     parameters, the function provided will be invoked for that operator.
 
     Example::
-        >>> @stateful_op_impl(torch.nn.functional.linear)
+        >>> @colo_op_impl(torch.nn.functional.linear)
         >>> def my_custom_linear(types, args, kwargs, process_group):
         >>>   ....
         >>>
         >>> input = torch.rand(10, 32)
-        >>> weight = StatefulTensorV2(torch.rand(32, 16))
-        >>> bias = StatefulTensorV2(torch.rand(16))
+        >>> weight = ColoTensor(torch.rand(32, 16))
+        >>> bias = ColoTensor(torch.rand(16))
         >>> # This will call `my_custom_linear` instead of the default.
         >>> torch.nn.functional.linear(input, weight, bias)
 
@@ -32,7 +47,7 @@ def stateful_op_impl(func):
     """
 
     def decorator_sharded_func(wrapped_func):
-        _register_stateful_op(func, wrapped_func)
+        _register_colo_op(func, wrapped_func)
 
         @functools.wraps(wrapped_func)
         def wrapper(*args, **kwargs):
diff --git a/colossalai/gemini/tensor/utils.py b/colossalai/tensor/utils.py
similarity index 64%
rename from colossalai/gemini/tensor/utils.py
rename to colossalai/tensor/utils.py
index 869d1ad1c..1430e5191 100644
--- a/colossalai/gemini/tensor/utils.py
+++ b/colossalai/tensor/utils.py
@@ -1,14 +1,10 @@
 import torch
-import torch.distributed as dist
-from torch.distributed import distributed_c10d
 
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
-def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
-    if not tensor.is_contiguous():
-        raise ValueError('input tensor is not a contiguous Tensor')
-    return StatefulTensorV2(tensor)
+def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
+    return ColoTensor(tensor)
 
 
 def convert_parameter(module: torch.nn.Module, param_name: str):
@@ -26,10 +22,10 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
 
     st = _convert_tensor(tensor)
 
-    # Replace param with StatefulTensorV2.
+    # Replace param with ColoTensor.
     # Need to delete the attribute first since param_name might be
-    # torch.nn.Parameter and can't be replaced with StatefulTensorV2 which is
+    # torch.nn.Parameter and can't be replaced with ColoTensor which is
     # not torch.nn.Parameter.
     delattr(module, param_name)
 
 
diff --git a/tests/test_gemini/test_tensor.py b/tests/test_tensor/test_op.py
similarity index 74%
rename from tests/test_gemini/test_tensor.py
rename to tests/test_tensor/test_op.py
index f403df5b4..4c9e72a92 100644
--- a/tests/test_gemini/test_tensor.py
+++ b/tests/test_tensor/test_op.py
@@ -1,10 +1,6 @@
 from numpy import allclose
 import torch
-from torch import nn
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
-# TODO(jiaruifang) auto import
-from colossalai.gemini.tensor._ops import *
-from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
 
@@ -18,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
 
-    sharded_weight = StatefulTensorV2(fc_ref.weight)
-    sharded_bias = StatefulTensorV2(fc_ref.bias)
+    sharded_weight = ColoTensor(fc_ref.weight)
+    sharded_bias = ColoTensor(fc_ref.bias)
 
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -45,15 +41,14 @@ def test_linear():
 
 # The test case failed
 # def test_uniform():
-#     t = StatefulTensorV2(torch.zeros(3, 5))
-#     # print(_STATEFUL_OPS)
+#     t = ColoTensor(torch.zeros(3, 5))
 #     torch.nn.init.uniform_(t)
 #     print(t)
 
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = StatefulTensorV2(t_ref.clone())
+    t = ColoTensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
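
Usage note (not part of the patch): the sketch below pieces together how the relocated API is meant to be used after this reorganization, based only on the `colo_op_impl` docstring, the exports in `colossalai/tensor/__init__.py`, and the tests above. The handler name `my_custom_linear` is the docstring's own placeholder, and the `convert_parameter` call assumes just the `(module, param_name)` signature shown in `utils.py`; treat this as an illustrative sketch, not documented library behavior.

```python
import torch
# Public names re-exported by colossalai/tensor/__init__.py in this patch.
from colossalai.tensor import ColoTensor, colo_op_impl, convert_parameter


def _unwrap(x):
    # ColoTensor wraps a plain tensor; torch_tensor() returns the wrapped value.
    return x.torch_tensor() if isinstance(x, ColoTensor) else x


# ColoTensor.__torch_function__ routes registered ops (torch.mean, F.linear,
# gelu, relu, ...) to the colo_* handlers collected in _COLOSSAL_OPS.
t = ColoTensor(torch.randn(3, 5))
mean = torch.mean(t)    # dispatched to colo_mean

# Overriding a handler, mirroring the colo_op_impl docstring. The wrapped
# function must accept (types, args, kwargs, process_group); otherwise
# _register_colo_op raises TypeError.
@colo_op_impl(torch.nn.functional.linear)
def my_custom_linear(types, args=(), kwargs=None, pg=None):
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else (kwargs or {}).get('bias', None)
    return torch.nn.functional.linear(_unwrap(input_tensor), _unwrap(weight), _unwrap(bias))


weight = ColoTensor(torch.rand(16, 32))    # (out_features, in_features)
bias = ColoTensor(torch.rand(16))
out = torch.nn.functional.linear(torch.rand(10, 32), weight, bias)    # calls my_custom_linear

# convert_parameter swaps a module's nn.Parameter for a ColoTensor in place,
# so later forward passes dispatch through the registered handlers.
fc = torch.nn.Linear(32, 16)
convert_parameter(fc, 'weight')
```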