[tensor] ZeRO uses ColoTensor as the base class. (#828)

* [refactor] moving InsertPostInitMethodToModuleSubClasses to utils.

* [tensor] ZeRO uses ColoTensor as the base class.

* polish
Jiarui Fang, 2022-04-22 12:00:48 +08:00, committed by GitHub
commit 294a6060d0 (parent 8e6fdb4f29)
4 changed files with 36 additions and 26 deletions


@@ -15,12 +15,12 @@ class ColoTensor(object):
         return super(ColoTensor, cls).__new__(cls)
 
     def __init__(
-            self,
-            *size: Tuple[int],
-            dtype=None,
-            requires_grad=False,
-            pin_memory=False,
-            torch_tensor=None,
+            self,
+            *size: Tuple[int],
+            dtype=None,
+            requires_grad=False,
+            pin_memory=False,
+            torch_tensor=torch.empty(0),
     ):
         self._size = size
         self._dtype = dtype
@@ -37,8 +37,13 @@ class ColoTensor(object):
                                              torch_tensor=tensor)
         return colo_t
 
+    def del_torch_tensor(self) -> None:
+        self._size = (0,)
+        self._torch_tensor = torch.empty(self._size)
+
     def torch_tensor(self) -> torch.Tensor:
-        if self._torch_tensor == None:
+        if self._torch_tensor == None or self._torch_tensor.numel() == 0:
+            print(self._size, type(self._size))
             self._torch_tensor = torch.empty(*self._size,
                                              dtype=self._dtype,
                                              requires_grad=self._requires_grad,
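The hunk above replaces the None placeholder with a zero-element tensor, so a ColoTensor can record only its metadata at construction and materialize storage on first access. A minimal standalone sketch of that lazy-payload pattern (a simplified mock that mirrors the method names in the diff, not the ColossalAI implementation):

import torch


class LazyTensorSketch:
    """Mock of ColoTensor's lazy payload: a 0-element tensor stands in for 'not allocated yet'."""

    def __init__(self, *size, dtype=None, requires_grad=False):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._torch_tensor = torch.empty(0)    # sentinel instead of None

    def del_torch_tensor(self) -> None:
        # free the payload; shrink back to the 0-element sentinel
        self._size = (0,)
        self._torch_tensor = torch.empty(self._size)

    def torch_tensor(self) -> torch.Tensor:
        # materialize on first access (or after a free)
        if self._torch_tensor.numel() == 0:
            self._torch_tensor = torch.empty(*self._size, dtype=self._dtype,
                                             requires_grad=self._requires_grad)
        return self._torch_tensor


t = LazyTensorSketch(2, 3, dtype=torch.float32)
assert t._torch_tensor.numel() == 0     # nothing allocated yet
assert t.torch_tensor().numel() == 6    # allocated lazily on first access
t.del_torch_tensor()
assert t.torch_tensor().numel() == 0    # a freed tensor stays usable as a 0-element payload

Using an empty tensor as the "not allocated" sentinel lets every caller touch the payload without None checks, which is what the ZeRO-side changes below rely on.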


@@ -20,8 +20,8 @@ class ShardedTensor(StatefulTensor):
 
     @property
     def dtype(self) -> torch.dtype:
-        assert self._payload.dtype == self._origin_dtype
-        return self._payload.dtype
+        assert self.torch_tensor().dtype == self._origin_dtype
+        return self.torch_tensor().dtype
 
     @property
     def origin_numel(self) -> int:
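With the payload owned by the ColoTensor base, ShardedTensor reads its dtype through torch_tensor() rather than a private _payload field, while still asserting that the payload has not drifted from the dtype recorded at construction. A small self-contained illustration of that invariant (hypothetical class name, not the ColossalAI one):

import torch


class DtypeGuardSketch:
    """Hypothetical stand-in for ShardedTensor's dtype property."""

    def __init__(self, tensor: torch.Tensor):
        self._torch_tensor = tensor
        self._origin_dtype = tensor.dtype     # remembered at construction

    def torch_tensor(self) -> torch.Tensor:
        return self._torch_tensor

    @property
    def dtype(self) -> torch.dtype:
        # the payload is always reached through torch_tensor(), never a raw field
        assert self.torch_tensor().dtype == self._origin_dtype
        return self.torch_tensor().dtype


st = DtypeGuardSketch(torch.zeros(4, dtype=torch.float16))
assert st.dtype == torch.float16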


@@ -1,6 +1,7 @@
 from enum import Enum
 from typing import Optional
 
 import torch
+from colossalai.tensor import ColoTensor
 
 
 class TensorState(Enum):
@@ -11,7 +12,7 @@ class TensorState(Enum):
     COMPUTE = 4
 
 
-class StatefulTensor(object):
+class StatefulTensor(ColoTensor):
     """A Structure stores a Torch Tensor and labeled states.
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
@@ -20,15 +21,20 @@
     """
 
     def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
+        if tensor is not None:
+            super().__init__(tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, \
+                pin_memory=tensor.pin_memory, torch_tensor=tensor)
+        else:
+            super().__init__(0)
+
         self._state = state
-        self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self._payload is None, f"payload has to None if state is {self._state}"
+            assert self.torch_tensor().numel() == 0, f"payload has to None if state is {self._state}"
 
     def data_ptr(self):
-        if self._payload is None:
+        if self.torch_tensor().numel() == 0:
             return None
-        return self._payload.data_ptr()
+        return self.torch_tensor().data_ptr()
 
     @property
     def state(self) -> TensorState:
@@ -36,42 +42,41 @@
     def set_null(self) -> None:
         self._state = TensorState.FREE
-        self._payload = None
+        self.del_torch_tensor()
 
     def is_null(self) -> bool:
         if self._state == TensorState.FREE:
-            assert self._payload is None
+            assert self.torch_tensor().numel() == 0
             return True
         return False
 
     def trans_state(self, state: TensorState) -> None:
         self._state = state
         if state == TensorState.FREE:
-            self._payload = None
+            self.del_torch_tensor()
 
     @property
     def payload(self) -> Optional[torch.Tensor]:
-        return self._payload
+        return self.torch_tensor()
 
     def copy_payload(self, tensor) -> None:
-        self._payload.view(-1).copy_(tensor.view(-1))
+        self.torch_tensor.view(-1).copy_(tensor.view(-1))
 
     def reset_payload(self, tensor) -> None:
-        del self._payload
-        self._payload = tensor
+        self._torch_tensor = tensor
         self.trans_state(TensorState.HOLD)
 
     @property
     def device(self) -> torch.device:
-        return self._payload.device
+        return self.torch_tensor().device
 
     @property
     def dtype(self) -> torch.dtype:
-        return self._payload.dtype
+        return self.torch_tensor().dtype
 
     @property
     def shape(self):
-        return self._payload.shape
+        return self.torch_tensor().shape
 
     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
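After this change a FREE StatefulTensor is represented by a zero-element payload instead of a None field, so the state checks above switch from `is None` to `numel() == 0`. A compact standalone sketch of that lifecycle (a simplified mock; the real class tracks more states and bookkeeping):

import torch
from enum import Enum


class State(Enum):
    FREE = 0
    HOLD = 1


class StatefulSketch:
    """Mock of the HOLD/FREE lifecycle: FREE now means a 0-element payload, not None."""

    def __init__(self, tensor: torch.Tensor):
        self._torch_tensor = tensor
        self._state = State.HOLD

    def torch_tensor(self) -> torch.Tensor:
        return self._torch_tensor

    def trans_state(self, state: State) -> None:
        self._state = state
        if state == State.FREE:
            # freeing drops the storage but keeps a valid (empty) tensor object
            self._torch_tensor = torch.empty(0)

    def is_null(self) -> bool:
        return self._state == State.FREE and self.torch_tensor().numel() == 0

    def reset_payload(self, tensor: torch.Tensor) -> None:
        self._torch_tensor = tensor
        self.trans_state(State.HOLD)


st = StatefulSketch(torch.zeros(8))
st.trans_state(State.FREE)
assert st.is_null() and st.torch_tensor().numel() == 0   # no None checks needed
st.reset_payload(torch.ones(4))
assert st.torch_tensor().numel() == 4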


@@ -60,8 +60,8 @@ def test_no_wrap_op():
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
 
 def test_lazy_init_tensor():
-    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor == None
+    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor.numel() == 0
     assert lazy_t.torch_tensor().numel() == 6
 
 
 def check_all():
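The updated test also documents the new calling convention: sizes are now passed as separate positional ints rather than a single tuple, and an unmaterialized ColoTensor holds a zero-element placeholder instead of None. The same behaviour as a short usage snippet (assuming ColossalAI at this commit):

import torch
from colossalai.tensor import ColoTensor

# lazy construction: only metadata is recorded, no storage is allocated yet
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0    # empty placeholder, not None

# the first torch_tensor() call materializes a 2x3 payload with the recorded dtype/requires_grad
assert lazy_t.torch_tensor().numel() == 6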