mirror of https://github.com/hpcaitech/ColossalAI
[tensor] lazy init (#823)
parent 68dcd51d41
commit 2ecc3d7a55
@@ -1,16 +1,48 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
+from typing import Tuple


 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
+    2. It supports lazy initialization of the tensor's payload.
     3. It can hijack torch functions that take ColoTensors as args and redirect them to our customized functions.
+    4. It supports distributing the tensor's payload into shards across processes. (TODO)
     """

     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(self, t: torch.Tensor) -> None:
-        self._torch_tensor = t
+    def __init__(self,
+                 *size: Tuple[int],
+                 dtype=None,
+                 requires_grad=False,
+                 pin_memory=False,
+                 torch_tensor=None):
+        self._size = size
+        self._dtype = dtype
+        self._requires_grad = requires_grad
+        self._pin_memory = pin_memory
+        self._torch_tensor = torch_tensor
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor):
+        colo_t = ColoTensor(*tensor.size(),
+                            dtype=tensor.dtype,
+                            requires_grad=tensor.requires_grad,
+                            pin_memory=tensor.is_pinned(),
+                            torch_tensor=tensor)
+        return colo_t
+
+    def torch_tensor(self) -> torch.Tensor:
+        if self._torch_tensor is None:
+            self._torch_tensor = torch.empty(*self._size,
+                                             dtype=self._dtype,
+                                             requires_grad=self._requires_grad,
+                                             pin_memory=self._pin_memory)
+        return self._torch_tensor

     @classmethod
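The core of the change: ColoTensor now records only size, dtype, requires_grad and pin_memory at construction time and allocates its payload on the first call to torch_tensor(). A minimal usage sketch of the API exactly as it appears in this hunk (variable names are illustrative; note that torch.empty leaves the materialized values uninitialized):

import torch
from colossalai.tensor import ColoTensor

# Lazy path: only metadata is recorded, no memory is allocated yet.
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None

# The first torch_tensor() call materializes the payload via torch.empty.
payload = lazy_t.torch_tensor()
assert payload.shape == (2, 3)

# Eager path: wrap an existing torch.Tensor without copying it.
eager_t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))
assert eager_t.torch_tensor().shape == (4, 4)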
@@ -1,4 +1,4 @@
-from numpy import allclose
+from numpy import allclose, require
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
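The remainder of test_linear is not shown in this excerpt; the comment and the delattr call point at the pattern under test: drop the registered nn.Parameter and re-attach a replacement tensor under the same attribute name so that the module's forward pass picks it up. Below is a self-contained sketch of that replacement pattern using plain torch.Tensors instead of ColoTensors (purely illustrative, not the test's actual body):

import torch
import torch.nn.functional as F
from copy import deepcopy

fc_ref = torch.nn.Linear(4, 8)
fc = deepcopy(fc_ref)

# Detached copies of the reference parameters act as stand-in payloads.
weight = fc_ref.weight.detach().clone()
bias = fc_ref.bias.detach().clone()

# Remove the registered nn.Parameters, then re-attach plain tensors under the
# same names; nn.Linear.forward only reads self.weight and self.bias.
delattr(fc, 'weight')
delattr(fc, 'bias')
setattr(fc, 'weight', weight)
setattr(fc, 'bias', bias)

x = torch.randn(1, 4)
assert torch.allclose(fc(x), F.linear(x, weight, bias))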
@@ -48,7 +48,7 @@ def test_linear():

 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
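test_element_wise relies on the "hijacking" described in the class docstring: torch.mean, gelu and relu receive a ColoTensor, which is not a torch.Tensor subclass, and still operate on its payload. ColossalAI routes these calls through PyTorch's __torch_function__ protocol and the _COLOSSAL_OPS registry imported at the top of the file; that dispatch code is outside this excerpt. A minimal, self-contained sketch of the same protocol with a toy wrapper (TensorWrapper and its unwrap-and-fall-back behaviour are illustrative, not ColossalAI code):

import torch


class TensorWrapper:
    """Toy stand-in for ColoTensor: holds a torch.Tensor payload."""

    def __init__(self, payload: torch.Tensor):
        self._payload = payload

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # A real dispatcher would first look func up in a registry such as
        # _COLOSSAL_OPS; here we simply unwrap the payloads and fall back to
        # the original torch function.
        def unwrap(x):
            return x._payload if isinstance(x, TensorWrapper) else x

        kwargs = {} if kwargs is None else kwargs
        args = tuple(unwrap(a) for a in args)
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)


t_ref = torch.randn(3, 5)
t = TensorWrapper(t_ref.clone())
assert torch.allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))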
@@ -57,10 +57,16 @@ def test_element_wise():
 # Test a function not wrapped by the op wrapper
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)


+def test_lazy_init_tensor():
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor is None
+    assert lazy_t.torch_tensor().numel() == 6
+
+
 if __name__ == '__main__':
     test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()
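test_lazy_init_tensor only inspects numel() after torch_tensor() has materialized the payload. The reason for keeping _size around as metadata is that shape-level questions can in principle be answered without allocating anything; the helper below is hypothetical (not part of this commit) and assumes the _size attribute exactly as defined in the new __init__:

import math

import torch
from colossalai.tensor import ColoTensor


def lazy_numel(t: ColoTensor) -> int:
    # Hypothetical helper: element count from recorded metadata, no allocation.
    return math.prod(t._size)


lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None        # nothing allocated yet
assert lazy_numel(lazy_t) == 6             # answered from metadata alone
assert lazy_t._torch_tensor is None        # still lazy
assert lazy_t.torch_tensor().numel() == 6  # materializes via torch.empty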