[tensor] lazy init (#823)

Jiarui Fang 2022-04-21 15:40:23 +08:00 committed by GitHub
parent 68dcd51d41
commit 2ecc3d7a55
2 changed files with 46 additions and 8 deletions


@@ -1,16 +1,48 @@
import torch
from .op_wrapper import _COLOSSAL_OPS
from typing import Tuple
class ColoTensor(object):
    """ Data Structure for Tensor in Colossal-AI
    1. It contains a torch.Tensor as an attribute.
    2. It supports lazy initialization of the tensor's payload.
    3. It can hijack torch functions that take ColoTensors as args and redirect them to our customized functions.
    4. It supports distributing the tensor's payload into shards across processes. (TODO)
    """
    def __new__(cls, *args, **kwargs):
        return super(ColoTensor, cls).__new__(cls)

    def __init__(self, t: torch.Tensor) -> None:
        self._torch_tensor = t
    def __init__(
        self,
        *size: Tuple[int],
        dtype=None,
        requires_grad=False,
        pin_memory=False,
        torch_tensor=None,
    ):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._pin_memory = pin_memory
        self._torch_tensor = torch_tensor
    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor):
        colo_t = ColoTensor(*tensor.size(),
                            dtype=tensor.dtype,
                            requires_grad=tensor.requires_grad,
                            pin_memory=tensor.is_pinned(),
                            torch_tensor=tensor)
        return colo_t
    def torch_tensor(self) -> torch.Tensor:
        if self._torch_tensor is None:
            self._torch_tensor = torch.empty(*self._size,
                                             dtype=self._dtype,
                                             requires_grad=self._requires_grad,
                                             pin_memory=self._pin_memory)
        return self._torch_tensor
    @classmethod
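
Taken together, this hunk gives ColoTensor two construction paths: eagerly wrapping an existing torch.Tensor through init_from_torch_tensor, and lazily building one from metadata, with the payload allocated only on the first torch_tensor() call. Below is a minimal usage sketch based solely on the API shown above; the import path is the one used by the test file below.

import torch
from colossalai.tensor import ColoTensor

# Lazy path: only shape/dtype/flags are stored; no payload is allocated yet.
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None

# First access materializes the payload via torch.empty(...).
payload = lazy_t.torch_tensor()
assert payload.shape == (2, 3) and payload.dtype == torch.float32

# Eager path: wrap an existing tensor, which becomes the payload directly.
t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))
assert t.torch_tensor().shape == (4, 4)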


@@ -1,4 +1,4 @@
from numpy import allclose
from numpy import allclose, require
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
@@ -14,8 +14,8 @@ def test_linear():
    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()
    sharded_weight = ColoTensor(fc_ref.weight)
    sharded_bias = ColoTensor(fc_ref.bias)
    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
    # replace the torch nn.Parameters with ShardedTensor
    delattr(fc, 'weight')
@@ -48,7 +48,7 @@ def test_linear():
def test_element_wise():
    t_ref = torch.randn(3, 5)
    t = ColoTensor(t_ref.clone())
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
    assert torch.mean(t) == torch.mean(t_ref)
    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
@@ -57,10 +57,16 @@ def test_element_wise():
# Test a function not wrapped by
def test_no_wrap_op():
    t_ref = torch.randn(3, 5)
    t = ColoTensor(t_ref.clone())
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
    assert torch.sum(t) == torch.sum(t_ref)


def test_lazy_init_tensor():
    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
    assert lazy_t._torch_tensor is None
    assert lazy_t.torch_tensor().numel() == 6


if __name__ == '__main__':
    test_no_wrap_op()
    test_lazy_init_tensor()
    # test_element_wise()
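
As a possible follow-up in the same style as test_lazy_init_tensor, here is a hedged sketch (the name test_lazy_init_metadata is hypothetical, not part of this changeset) checking that the metadata recorded at construction time reaches the lazily allocated payload:

def test_lazy_init_metadata():
    # Hypothetical extra check: dtype and requires_grad passed to the lazy
    # constructor should show up on the payload created by torch_tensor().
    lazy_t = ColoTensor(2, 3, dtype=torch.float16, requires_grad=True)
    payload = lazy_t.torch_tensor()
    assert payload.shape == (2, 3)
    assert payload.dtype == torch.float16
    assert payload.requires_grad is True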