[gemini] a new tensor structure (#818)

* Revert "[zero] add ZeroTensorShardStrategy (#793)"

This reverts commit 88759e289e.

* [gemini] set cpu memory capacity

* [log] local throughput collecting

* polish

* polish

* polish

* polish code

* polish

* polish code

* add a new tensor structure and override linear for it

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
pull/820/head
Jiarui Fang 2022-04-21 11:42:37 +08:00 committed by GitHub
parent 413ce30c45
commit ab962b9735
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 281 additions and 0 deletions

View File

@ -0,0 +1,43 @@
import functools
from .api import (
_register_stateful_op,)
def stateful_op_impl(func):
"""
Provides a way for users to write their own custom operator. This
can be used to override existing StatefulTensorV2 operators or write a new
one not supported by StatefulTensorV2. If the operator in question is covered
by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
parameters, the function provided will be invoked for that operator.
Example::
>>> @stateful_op_impl(torch.nn.functional.linear)
>>> def my_custom_linear(types, args, kwargs, process_group):
>>> ....
>>>
>>> input = torch.rand(10, 32)
>>> weight = StatefulTensorV2(torch.rand(32, 16))
>>> bias = StatefulTensorV2(torch.rand(16))
>>> # This will call `my_custom_linear` instead of the default.
>>> torch.nn.functional.linear(input, weight, bias)
The types, args and kwargs parameters are the same parameters that are
passed to ``__torch_function__`` dispatch API
(https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
Args:
func(Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.nn.functional.linear)
"""
def decorator_sharded_func(wrapped_func):
_register_stateful_op(func, wrapped_func)
@functools.wraps(wrapped_func)
def wrapper(*args, **kwargs):
return wrapped_func(*args, **kwargs)
return wrapper
return decorator_sharded_func

View File

@ -0,0 +1,3 @@
from .init import stateful_uniform
from .linear import stateful_linear
from .element_wise import stateful_mean

View File

@ -0,0 +1,29 @@
import torch
from colossalai.gemini.tensor import stateful_op_impl
from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
@stateful_op_impl(torch.mean)
def stateful_mean(types, args=(), kwargs=None, pg=None):
stateful_tensor = args[0]
return torch.mean(stateful_tensor.torch_tensor())
def register_elementwise_op(op):
@stateful_op_impl(op)
def elementwise_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
input_tensor = args[0]
# Validate types
if not isinstance(input_tensor, StatefulTensorV2):
raise TypeError("input needs to be a StatefulTensorV2")
return op(input_tensor.torch_tensor())
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)

View File

@ -0,0 +1,29 @@
import torch
from colossalai.gemini.tensor import stateful_op_impl
def validate_param(param, param_name):
if param is None:
raise ValueError(f"param: {param_name} shouldn't be None!")
@stateful_op_impl(torch.nn.init.uniform_)
def stateful_uniform(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
sharded_tensor: tensor sharded across devices
a: the lower bound of the uniform distribution
b: the upper bound of the uniform distribution
"""
validate_param(kwargs, "kwargs")
stateful_tensor = kwargs["tensor"]
validate_param(stateful_tensor, "stateful_tensor")
a = kwargs['a']
validate_param(a, "a")
b = kwargs['b']
validate_param(b, "b")
torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
return stateful_tensor

View File

@ -0,0 +1,29 @@
import torch
from colossalai.gemini.tensor import stateful_op_impl
from ..stateful_tensor import StatefulTensorV2
from packaging import version
@stateful_op_impl(torch.nn.functional.linear)
def stateful_linear(types, args, kwargs, pg):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
input_tensor = args[0]
weight = args[1]
if version.parse(torch.__version__) > version.parse("1.11.0"):
if len(args) == 3:
bias = args[2]
else:
bias = None
else:
bias = kwargs.get('bias', None)
if isinstance(bias, StatefulTensorV2):
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
if isinstance(weight, StatefulTensorV2):
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
else:
return torch.nn.functional.linear(input_tensor, weight, bias)

View File

@ -0,0 +1,17 @@
from typing import (
Callable,
Dict,
)
# Custom sharded ops
_STATEFUL_OPS: Dict[str, Callable] = {}
def _register_stateful_op(op, func):
from inspect import signature
if len(signature(func).parameters) != 4:
raise TypeError(f'Custom stateful op function expects signature: '
f'(types, args, kwargs, process_group), but received '
f'signature: {signature(func)}')
global _STATEFUL_OPS
_STATEFUL_OPS[op] = func

View File

@ -0,0 +1,30 @@
import torch
from .api import _STATEFUL_OPS
class StatefulTensorV2(object):
def __new__(cls, *args, **kwargs):
return super(StatefulTensorV2, cls).__new__(cls)
def __init__(self, t: torch.Tensor) -> None:
self._torch_tensor = t
def torch_tensor(self) -> torch.Tensor:
return self._torch_tensor
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
global _STATEFUL_OPS
if func in _STATEFUL_OPS:
# Find StatefulTensorV2 instance to get process_group.
for arg in args:
if isinstance(arg, StatefulTensorV2):
return _STATEFUL_OPS[func](types, args, kwargs, None)
for kwarg in kwargs.values():
if isinstance(kwarg, StatefulTensorV2):
return _STATEFUL_OPS[func](types, args, kwargs, None)
raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
f"kwargs: {kwargs} not supported for StatefulTensorV2!")

View File

@ -0,0 +1,37 @@
import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
if not tensor.is_contiguous():
raise ValueError('input tensor is not a contiguous Tensor')
return StatefulTensorV2(tensor)
def convert_parameter(module: torch.nn.Module, param_name: str):
# Perform some validation first.
if not hasattr(module, param_name):
raise ValueError(f'module: {module} does not have parameter with name: {param_name}')
tensor = getattr(module, param_name)
if not isinstance(tensor, torch.Tensor):
raise ValueError(
f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')
if not tensor.is_contiguous():
raise ValueError(f'param: {param_name} is not a contiguous Tensor')
st = _convert_tensor(tensor)
# Replace param with StatefulTensorV2.
# Need to delete the attribute first since param_name might be
# torch.nn.Parameter and can't be replaced with StatefulTensorV2 which is
# not torch.nn.Parameter.
delattr(module, param_name)
# Now we can set the attribute appropriately.
setattr(module, param_name, st)

View File

@ -0,0 +1,64 @@
from numpy import allclose
import torch
from torch import nn
from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
# TODO(jiaruifang) auto import
from colossalai.gemini.tensor._ops import *
from colossalai.gemini.tensor.api import _STATEFUL_OPS
from copy import deepcopy
def test_linear():
in_dim = 4
out_dim = 5
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
fc_ref = deepcopy(fc)
input_ref = torch.randn(1, in_dim)
input_tensor = input_ref.clone()
sharded_weight = StatefulTensorV2(fc_ref.weight)
sharded_bias = StatefulTensorV2(fc_ref.bias)
# replace the torch nn.Parameters with ShardedTensor
delattr(fc, 'weight')
setattr(fc, 'weight', sharded_weight)
delattr(fc, 'bias')
setattr(fc, 'bias', sharded_bias)
fc.weight.requires_grad = True
fc.bias.requires_grad = True
# torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
out = fc(input_tensor)
loss = out.sum()
loss.backward()
out_ref = fc_ref(input_ref)
loss_ref = out_ref.sum()
loss_ref.backward()
assert (loss_ref == loss)
assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
# The test case failed
# def test_uniform():
# t = StatefulTensorV2(torch.zeros(3, 5))
# # print(_STATEFUL_OPS)
# torch.nn.init.uniform_(t)
# print(t)
def test_element_wise():
t_ref = torch.randn(3, 5)
t = StatefulTensorV2(t_ref.clone())
assert torch.mean(t) == torch.mean(t_ref)
assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
if __name__ == '__main__':
test_linear()
# test_element_wise()