mirror of https://github.com/hpcaitech/ColossalAI
[tensor] lazy init (#823)
parent 68dcd51d41
commit 2ecc3d7a55
@@ -1,16 +1,48 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
+from typing import Tuple


 class ColoTensor(object):
+    """ Data Structure for Tensor in Colossal-AI
+    1. It contains a torch.Tensor as an attribute.
+    2. It supports lazy init the tensor's payload.
+    3. It can hijack the torch functions which using ColoTensors as args to our customized functions.
+    4. It supports distributing the tensor's payload to the shards among processes. (TODO)
+    """

     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(self, t: torch.Tensor) -> None:
-        self._torch_tensor = t
+    def __init__(
+            self,
+            *size: Tuple[int],
+            dtype=None,
+            requires_grad=False,
+            pin_memory=False,
+            torch_tensor=None,
+    ):
+        self._size = size
+        self._dtype = dtype
+        self._requires_grad = requires_grad
+        self._pin_memory = pin_memory
+        self._torch_tensor = torch_tensor
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor):
+        colo_t = ColoTensor(*tensor.size(),
+                            dtype=tensor.dtype,
+                            requires_grad=tensor.requires_grad,
+                            pin_memory=tensor.pin_memory,
+                            torch_tensor=tensor)
+        return colo_t

     def torch_tensor(self) -> torch.Tensor:
+        if self._torch_tensor == None:
+            self._torch_tensor = torch.empty(*self._size,
+                                             dtype=self._dtype,
+                                             requires_grad=self._requires_grad,
+                                             pin_memory=self._pin_memory)
         return self._torch_tensor

     @classmethod
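Taken together, the hunk above replaces the eager wrapper (`__init__(self, t)`) with a metadata-only constructor plus an explicit `init_from_torch_tensor` path, and defers the `torch.empty` allocation to the first `torch_tensor()` call. A short usage sketch, assuming this commit's version of `ColoTensor` and the `colossalai.tensor` import path used by the tests below; behavior beyond what the hunk shows is not assumed:

import torch
from colossalai.tensor import ColoTensor  # import path taken from the test file below

# Lazy path: only shape/dtype/flags are recorded, no storage is allocated yet.
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None       # payload not materialized

payload = lazy_t.torch_tensor()           # first access allocates via torch.empty
assert payload.shape == (2, 3)

# Eager path: wrap an existing torch.Tensor without copying it.
x = torch.randn(4, 4)
wrapped = ColoTensor.init_from_torch_tensor(x)
assert wrapped.torch_tensor() is x        # the original tensor is reused as the payload

The second group of hunks below updates the unit tests accordingly: the removed `ColoTensor(tensor)` form is migrated to `init_from_torch_tensor`, and a new `test_lazy_init_tensor` exercises the lazy path.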
@@ -1,4 +1,4 @@
-from numpy import allclose
+from numpy import allclose, require
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -48,7 +48,7 @@ def test_linear():

 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
@@ -57,10 +57,16 @@ def test_element_wise():
 # Test a function not wrapped by
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)


+def test_lazy_init_tensor():
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor == None
+    assert lazy_t.torch_tensor().numel() == 6
+
+
 if __name__ == '__main__':
-    test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()
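The pattern these tests exercise is independent of Colossal-AI: keep the tensor's metadata and allocate the payload only when it is first requested, which leaves room to decide later where the storage should live, in line with the sharding TODO in the docstring. A minimal, dependency-free sketch of that idea, using a hypothetical LazyTensor class that is not part of the library:

import torch


class LazyTensor:
    """Minimal illustration of lazy payload init; the class and names are hypothetical."""

    def __init__(self, *size, dtype=torch.float32, requires_grad=False):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._payload = None              # nothing allocated at construction time

    def materialize(self) -> torch.Tensor:
        # Allocate on first access only; later calls return the cached payload.
        if self._payload is None:
            self._payload = torch.empty(*self._size,
                                        dtype=self._dtype,
                                        requires_grad=self._requires_grad)
        return self._payload


t = LazyTensor(2, 3, dtype=torch.float32)
assert t._payload is None                 # construction costs no tensor memory
assert t.materialize().numel() == 6       # allocation happens here, mirroring test_lazy_init_tensor

As with `torch_tensor()` in the commit, repeated calls return the cached payload, so materialization happens exactly once.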