diff --git a/colossalai/gemini/tensor/__init__.py b/colossalai/gemini/tensor/__init__.py
new file mode 100644
index 000000000..fcf909ba4
--- /dev/null
+++ b/colossalai/gemini/tensor/__init__.py
@@ -0,0 +1,43 @@
+import functools
+from .api import (
+    _register_stateful_op,)
+
+
+def stateful_op_impl(func):
+    """
+    Provides a way for users to write their own custom operator. This
+    can be used to override existing StatefulTensorV2 operators or write a new
+    one not supported by StatefulTensorV2. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
+    parameters, the function provided will be invoked for that operator.
+
+    Example::
+        >>> @stateful_op_impl(torch.nn.functional.linear)
+        >>> def my_custom_linear(types, args, kwargs, process_group):
+        >>>     ....
+        >>>
+        >>> input = torch.rand(10, 32)
+        >>> weight = StatefulTensorV2(torch.rand(32, 16))
+        >>> bias = StatefulTensorV2(torch.rand(16))
+        >>> # This will call `my_custom_linear` instead of the default.
+        >>> torch.nn.functional.linear(input, weight, bias)
+
+    The types, args and kwargs parameters are the same parameters that are
+    passed to the ``__torch_function__`` dispatch API
+    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
+
+    Args:
+        func (Callable): torch function for which we want to provide a stateful
+            implementation (e.g. torch.nn.functional.linear)
+    """
+
+    def decorator_sharded_func(wrapped_func):
+        _register_stateful_op(func, wrapped_func)
+
+        @functools.wraps(wrapped_func)
+        def wrapper(*args, **kwargs):
+            return wrapped_func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator_sharded_func
diff --git a/colossalai/gemini/tensor/_ops/__init__.py b/colossalai/gemini/tensor/_ops/__init__.py
new file mode 100644
index 000000000..199f456ee
--- /dev/null
+++ b/colossalai/gemini/tensor/_ops/__init__.py
@@ -0,0 +1,3 @@
+from .init import stateful_uniform
+from .linear import stateful_linear
+from .element_wise import stateful_mean
\ No newline at end of file
diff --git a/colossalai/gemini/tensor/_ops/element_wise.py b/colossalai/gemini/tensor/_ops/element_wise.py
new file mode 100644
index 000000000..773ce4799
--- /dev/null
+++ b/colossalai/gemini/tensor/_ops/element_wise.py
@@ -0,0 +1,29 @@
+import torch
+from colossalai.gemini.tensor import stateful_op_impl
+from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+
+
+@stateful_op_impl(torch.mean)
+def stateful_mean(types, args=(), kwargs=None, pg=None):
+    stateful_tensor = args[0]
+    return torch.mean(stateful_tensor.torch_tensor())
+
+
+def register_elementwise_op(op):
+
+    @stateful_op_impl(op)
+    def elementwise_op(types, args=(), kwargs=None, pg=None):
+        """
+        Handles ``__torch_function__`` dispatch for an elementwise op such
+        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
+        The StatefulTensorV2 input is unwrapped before the op is applied.
+        """
+        input_tensor = args[0]
+        # Validate types
+        if not isinstance(input_tensor, StatefulTensorV2):
+            raise TypeError("input needs to be a StatefulTensorV2")
+        return op(input_tensor.torch_tensor())
+
+
+register_elementwise_op(torch.nn.functional.gelu)
+register_elementwise_op(torch.nn.functional.relu)
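For reference, a minimal usage sketch (not part of the patch) of the elementwise dispatch above, assuming the `_ops` package has been imported so the handlers are registered:

    import torch
    from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
    from colossalai.gemini.tensor._ops import *    # importing _ops registers the handlers

    t = StatefulTensorV2(torch.randn(3, 5))
    out = torch.nn.functional.relu(t)    # intercepted by __torch_function__ and routed
                                         # to the registered elementwise handler
    mean = torch.mean(t)                 # routed to stateful_mean
    # Both results are plain torch.Tensor objects, since the handlers unwrap the input.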
+ """ + input_tensor = args[0] + # Validate types + if not isinstance(input_tensor, StatefulTensorV2): + raise TypeError("input needs to be a StatefulTensorV2") + return op(input_tensor.torch_tensor()) + + +register_elementwise_op(torch.nn.functional.gelu) +register_elementwise_op(torch.nn.functional.relu) diff --git a/colossalai/gemini/tensor/_ops/init.py b/colossalai/gemini/tensor/_ops/init.py new file mode 100644 index 000000000..079ffe7c3 --- /dev/null +++ b/colossalai/gemini/tensor/_ops/init.py @@ -0,0 +1,29 @@ +import torch +from colossalai.gemini.tensor import stateful_op_impl + + +def validate_param(param, param_name): + if param is None: + raise ValueError(f"param: {param_name} shouldn't be None!") + + +@stateful_op_impl(torch.nn.init.uniform_) +def stateful_uniform(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + Args: + sharded_tensor: tensor sharded across devices + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + """ + validate_param(kwargs, "kwargs") + stateful_tensor = kwargs["tensor"] + validate_param(stateful_tensor, "stateful_tensor") + a = kwargs['a'] + validate_param(a, "a") + b = kwargs['b'] + validate_param(b, "b") + + torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b) + return stateful_tensor diff --git a/colossalai/gemini/tensor/_ops/linear.py b/colossalai/gemini/tensor/_ops/linear.py new file mode 100644 index 000000000..7998e353d --- /dev/null +++ b/colossalai/gemini/tensor/_ops/linear.py @@ -0,0 +1,29 @@ +import torch +from colossalai.gemini.tensor import stateful_op_impl +from ..stateful_tensor import StatefulTensorV2 +from packaging import version + + +@stateful_op_impl(torch.nn.functional.linear) +def stateful_linear(types, args, kwargs, pg): + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. + This method computes a linear. + """ + input_tensor = args[0] + weight = args[1] + + if version.parse(torch.__version__) > version.parse("1.11.0"): + if len(args) == 3: + bias = args[2] + else: + bias = None + else: + bias = kwargs.get('bias', None) + if isinstance(bias, StatefulTensorV2): + bias = bias.torch_tensor() + + # Add communication logic before and after linear call. 
diff --git a/colossalai/gemini/tensor/api.py b/colossalai/gemini/tensor/api.py
new file mode 100644
index 000000000..92a7e98fb
--- /dev/null
+++ b/colossalai/gemini/tensor/api.py
@@ -0,0 +1,17 @@
+from typing import (
+    Callable,
+    Dict,
+)
+
+# Custom stateful ops, keyed by the torch function they override.
+_STATEFUL_OPS: Dict[Callable, Callable] = {}
+
+
+def _register_stateful_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 4:
+        raise TypeError(f'Custom stateful op function expects signature: '
+                        f'(types, args, kwargs, process_group), but received '
+                        f'signature: {signature(func)}')
+    global _STATEFUL_OPS
+    _STATEFUL_OPS[op] = func
diff --git a/colossalai/gemini/tensor/stateful_tensor.py b/colossalai/gemini/tensor/stateful_tensor.py
new file mode 100644
index 000000000..dbfd088b2
--- /dev/null
+++ b/colossalai/gemini/tensor/stateful_tensor.py
@@ -0,0 +1,30 @@
+import torch
+from .api import _STATEFUL_OPS
+
+
+class StatefulTensorV2(object):
+
+    def __new__(cls, *args, **kwargs):
+        return super(StatefulTensorV2, cls).__new__(cls)
+
+    def __init__(self, t: torch.Tensor) -> None:
+        self._torch_tensor = t
+
+    def torch_tensor(self) -> torch.Tensor:
+        return self._torch_tensor
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        global _STATEFUL_OPS
+        if func in _STATEFUL_OPS:
+            # Dispatch only if a StatefulTensorV2 is among the arguments; the process group is not used yet.
+            for arg in args:
+                if isinstance(arg, StatefulTensorV2):
+                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+
+            for kwarg in (kwargs or {}).values():
+                if isinstance(kwarg, StatefulTensorV2):
+                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+
+        raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
+                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
diff --git a/colossalai/gemini/tensor/utils.py b/colossalai/gemini/tensor/utils.py
new file mode 100644
index 000000000..869d1ad1c
--- /dev/null
+++ b/colossalai/gemini/tensor/utils.py
@@ -0,0 +1,37 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import distributed_c10d
+
+from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+
+
+def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
+    if not tensor.is_contiguous():
+        raise ValueError('input tensor is not a contiguous Tensor')
+    return StatefulTensorV2(tensor)
+
+
+def convert_parameter(module: torch.nn.Module, param_name: str):
+    # Perform some validation first.
+    if not hasattr(module, param_name):
+        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')
+
+    tensor = getattr(module, param_name)
+    if not isinstance(tensor, torch.Tensor):
+        raise ValueError(
+            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')
+
+    if not tensor.is_contiguous():
+        raise ValueError(f'param: {param_name} is not a contiguous Tensor')
+
+    st = _convert_tensor(tensor)
+
+    # Replace the param with a StatefulTensorV2.
+
+    # Need to delete the attribute first, since param_name might refer to a
+    # torch.nn.Parameter, which can't be replaced in place by a StatefulTensorV2
+    # (StatefulTensorV2 is not a torch.nn.Parameter).
+    delattr(module, param_name)
+
+    # Now we can set the attribute appropriately.
+    setattr(module, param_name, st)
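For reference, a minimal sketch (not part of the patch) of `convert_parameter`, assuming the `_ops` package has been imported; after conversion, the module's forward pass goes through the stateful dispatch:

    import torch
    from colossalai.gemini.tensor.utils import convert_parameter
    from colossalai.gemini.tensor._ops import *    # registers stateful_linear

    fc = torch.nn.Linear(4, 5)
    convert_parameter(fc, 'weight')    # fc.weight is now a StatefulTensorV2
    out = fc(torch.randn(1, 4))        # F.linear inside forward() dispatches to stateful_linear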
diff --git a/tests/test_gemini/test_tensor.py b/tests/test_gemini/test_tensor.py
new file mode 100644
index 000000000..f403df5b4
--- /dev/null
+++ b/tests/test_gemini/test_tensor.py
@@ -0,0 +1,64 @@
+from numpy import allclose
+import torch
+from torch import nn
+from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+# TODO(jiaruifang) auto import
+from colossalai.gemini.tensor._ops import *
+from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from copy import deepcopy
+
+
+def test_linear():
+    in_dim = 4
+    out_dim = 5
+
+    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+    fc_ref = deepcopy(fc)
+
+    input_ref = torch.randn(1, in_dim)
+    input_tensor = input_ref.clone()
+
+    sharded_weight = StatefulTensorV2(fc_ref.weight)
+    sharded_bias = StatefulTensorV2(fc_ref.bias)
+
+    # replace the torch nn.Parameters with StatefulTensorV2
+    delattr(fc, 'weight')
+    setattr(fc, 'weight', sharded_weight)
+    delattr(fc, 'bias')
+    setattr(fc, 'bias', sharded_bias)
+
+    fc.weight.requires_grad = True
+    fc.bias.requires_grad = True
+
+    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
+    out = fc(input_tensor)
+    loss = out.sum()
+    loss.backward()
+
+    out_ref = fc_ref(input_ref)
+    loss_ref = out_ref.sum()
+    loss_ref.backward()
+
+    assert (loss_ref == loss)
+    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
+
+
+# This test case currently fails.
+# def test_uniform():
+#     t = StatefulTensorV2(torch.zeros(3, 5))
+#     # print(_STATEFUL_OPS)
+#     torch.nn.init.uniform_(t)
+#     print(t)
+
+
+def test_element_wise():
+    t_ref = torch.randn(3, 5)
+    t = StatefulTensorV2(t_ref.clone())
+    assert torch.mean(t) == torch.mean(t_ref)
+    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+
+
+if __name__ == '__main__':
+    test_linear()
+    # test_element_wise()
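As a final illustration (not part of the patch), the registry can be extended with handlers for additional torch functions through `stateful_op_impl`; `torch.tanh` below is a hypothetical example, and registering a handler for an already-covered function simply overwrites the default entry in `_STATEFUL_OPS`:

    import torch
    from colossalai.gemini.tensor import stateful_op_impl
    from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2

    @stateful_op_impl(torch.tanh)
    def stateful_tanh(types, args=(), kwargs=None, pg=None):
        # Unwrap the StatefulTensorV2 and fall back to the plain tensor op.
        return torch.tanh(args[0].torch_tensor())

    t = StatefulTensorV2(torch.randn(3, 5))
    out = torch.tanh(t)    # now routed to stateful_tanh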