mirror of https://github.com/hpcaitech/ColossalAI
[tensor] ZeRO use ColoTensor as the base class. (#828)
* [refactor] moving InsertPostInitMethodToModuleSubClasses to utils.
* [tensor] ZeRO use ColoTensor as the base class.
* polish

parent 8e6fdb4f29
commit 294a6060d0
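At a high level, the change makes StatefulTensor inherit from ColoTensor, so the raw torch.Tensor payload is owned by the base class and materialized lazily instead of living in a private _payload attribute. A minimal sketch of the resulting relationship, with simplified signatures (the real classes in colossalai carry more state than shown here):

    import torch


    class ColoTensor:
        """Simplified stand-in for colossalai.tensor.ColoTensor (not the real class)."""

        def __init__(self, *size, dtype=None, requires_grad=False,
                     pin_memory=False, torch_tensor=torch.empty(0)):
            self._size = size
            self._dtype = dtype
            self._requires_grad = requires_grad
            self._pin_memory = pin_memory
            self._torch_tensor = torch_tensor      # empty sentinel means "not materialized yet"

        def torch_tensor(self) -> torch.Tensor:
            if self._torch_tensor.numel() == 0:    # materialize on first access
                self._torch_tensor = torch.empty(*self._size, dtype=self._dtype,
                                                 requires_grad=self._requires_grad)
            return self._torch_tensor


    class StatefulTensor(ColoTensor):
        """Simplified: the real class also tracks a TensorState."""

        @property
        def payload(self) -> torch.Tensor:
            return self.torch_tensor()             # payload now delegates to the base class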
@@ -15,12 +15,12 @@ class ColoTensor(object):
         return super(ColoTensor, cls).__new__(cls)

     def __init__(
             self,
             *size: Tuple[int],
             dtype=None,
             requires_grad=False,
             pin_memory=False,
-            torch_tensor=None,
+            torch_tensor=torch.empty(0),
     ):
         self._size = size
         self._dtype = dtype
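Defaulting torch_tensor to an empty tensor rather than None gives every ColoTensor a concrete tensor to query, so "payload not allocated / freed" can be detected uniformly with numel() == 0 instead of scattered None checks. The helpers below are illustrative only (not part of the API) and just contrast the two conventions:

    import torch
    from typing import Optional

    # Old convention: "no payload" is spelled None, so every caller needs a None check.
    def has_payload_old(t: Optional[torch.Tensor]) -> bool:
        return t is not None

    # New convention: "no payload" is an empty tensor, so numel() == 0 is the single test,
    # matching the checks used by torch_tensor() and the StatefulTensor asserts in this diff.
    def has_payload_new(t: torch.Tensor) -> bool:
        return t.numel() != 0

    print(has_payload_old(None))               # False
    print(has_payload_new(torch.empty(0)))     # False
    print(has_payload_new(torch.zeros(2, 3)))  # True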
@@ -37,8 +37,13 @@ class ColoTensor(object):
                               torch_tensor=tensor)
         return colo_t

+    def del_torch_tensor(self) -> None:
+        self._size = (0,)
+        self._torch_tensor = torch.empty(self._size)
+
     def torch_tensor(self) -> torch.Tensor:
-        if self._torch_tensor == None:
+        if self._torch_tensor == None or self._torch_tensor.numel() == 0:
+            print(self._size, type(self._size))
             self._torch_tensor = torch.empty(*self._size,
                                              dtype=self._dtype,
                                              requires_grad=self._requires_grad,
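The new del_torch_tensor() releases the payload by shrinking the recorded size to (0,) and swapping in an empty tensor, while torch_tensor() allocates the buffer on first use. A rough usage sketch, assuming ColoTensor behaves as in the hunk above:

    import torch
    from colossalai.tensor import ColoTensor   # import path taken from this diff

    t = ColoTensor(2, 3, dtype=torch.float32)  # no storage allocated yet
    assert t._torch_tensor.numel() == 0        # only the empty sentinel exists

    buf = t.torch_tensor()                     # first access materializes a 2x3 buffer
    assert buf.numel() == 6

    t.del_torch_tensor()                       # size becomes (0,) and the storage is dropped
    assert t.torch_tensor().numel() == 0       # later accesses only ever see an empty tensor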
@@ -20,8 +20,8 @@ class ShardedTensor(StatefulTensor):

     @property
     def dtype(self) -> torch.dtype:
-        assert self._payload.dtype == self._origin_dtype
-        return self._payload.dtype
+        assert self.torch_tensor().dtype == self._origin_dtype
+        return self.torch_tensor().dtype

     @property
     def origin_numel(self) -> int:
@@ -1,6 +1,7 @@
 from enum import Enum
 from typing import Optional
 import torch
+from colossalai.tensor import ColoTensor


 class TensorState(Enum):
@@ -11,7 +12,7 @@ class TensorState(Enum):
     COMPUTE = 4


-class StatefulTensor(object):
+class StatefulTensor(ColoTensor):
     """A Structure stores a Torch Tensor and labeled states.
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
@@ -20,15 +21,20 @@ class StatefulTensor(object):
     """

     def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
+        if tensor is not None:
+            super().__init__(tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, \
+                pin_memory=tensor.pin_memory, torch_tensor=tensor)
+        else:
+            super().__init__(0)
+
         self._state = state
-        self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self._payload is None, f"payload has to None if state is {self._state}"
+            assert self.torch_tensor().numel() == 0, f"payload has to None if state is {self._state}"

     def data_ptr(self):
-        if self._payload is None:
+        if self.torch_tensor().numel() == 0:
             return None
-        return self._payload.data_ptr()
+        return self.torch_tensor().data_ptr()

     @property
     def state(self) -> TensorState:
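With ColoTensor as the base class, StatefulTensor.__init__ no longer stores the tensor itself: it forwards either the wrapped tensor's metadata or a zero size to super().__init__, and the FREE state is validated against an empty payload. A sketch of the two construction paths (the import path is an assumption about where this module lives):

    import torch
    # Assumed module path for StatefulTensor/TensorState in this version of the code base.
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    # Path 1: wrap an existing tensor -- its metadata goes to ColoTensor.__init__ and the
    # tensor itself becomes the base class's _torch_tensor.
    payload = torch.zeros(4, 4)
    st = StatefulTensor(payload)                        # state defaults to TensorState.HOLD
    assert st.payload.data_ptr() == payload.data_ptr()

    # Path 2: no tensor yet -- super().__init__(0) leaves only the empty sentinel, which is
    # exactly what the FREE-state assert (numel() == 0) relies on.
    empty_st = StatefulTensor(None, state=TensorState.FREE)
    assert empty_st.data_ptr() is None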
@@ -36,42 +42,41 @@ class StatefulTensor(object):

     def set_null(self) -> None:
         self._state = TensorState.FREE
-        self._payload = None
+        self.del_torch_tensor()

     def is_null(self) -> bool:
         if self._state == TensorState.FREE:
-            assert self._payload is None
+            assert self.torch_tensor().numel() == 0
             return True
         return False

     def trans_state(self, state: TensorState) -> None:
         self._state = state
         if state == TensorState.FREE:
-            self._payload = None
+            self.del_torch_tensor()

     @property
     def payload(self) -> Optional[torch.Tensor]:
-        return self._payload
+        return self.torch_tensor()

     def copy_payload(self, tensor) -> None:
-        self._payload.view(-1).copy_(tensor.view(-1))
+        self.torch_tensor.view(-1).copy_(tensor.view(-1))

     def reset_payload(self, tensor) -> None:
-        del self._payload
-        self._payload = tensor
+        self._torch_tensor = tensor
         self.trans_state(TensorState.HOLD)

     @property
     def device(self) -> torch.device:
-        return self._payload.device
+        return self.torch_tensor().device

     @property
     def dtype(self) -> torch.dtype:
-        return self._payload.dtype
+        return self.torch_tensor().dtype

     @property
     def shape(self):
-        return self._payload.shape
+        return self.torch_tensor().shape

     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
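All payload accessors (payload, data_ptr, device, dtype, shape) now route through torch_tensor(), and freeing is expressed with del_torch_tensor() instead of assigning None to _payload. A small lifecycle sketch under the same assumed import path as above:

    import torch
    # Assumed module path; adjust to wherever StatefulTensor lives in this version.
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    st = StatefulTensor(torch.ones(8))
    print(st.shape, st.dtype, st.device)   # all three are read through the ColoTensor payload

    new_buf = torch.zeros(8)
    st.reset_payload(new_buf)              # replaces _torch_tensor and moves back to HOLD
    assert st.payload is new_buf and st.state == TensorState.HOLD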
@@ -60,8 +60,8 @@ def test_no_wrap_op():
     assert torch.sum(input=t) == torch.sum(input=t_ref)

 def test_lazy_init_tensor():
-    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor == None
+    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor.numel() == 0
     assert lazy_t.torch_tensor().numel() == 6

 def check_all():
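The updated test builds the lazy tensor with varargs sizes (ColoTensor(2, 3, ...)) and checks the empty sentinel through numel() instead of comparing against None. A sketch of an additional check in the same spirit (not part of the PR's test file):

    import torch
    from colossalai.tensor import ColoTensor

    def test_lazy_init_materializes_once():
        lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
        assert lazy_t._torch_tensor.numel() == 0      # nothing allocated until first access

        first = lazy_t.torch_tensor()
        assert first.shape == (2, 3) and first.requires_grad

        # A second access should return the already-materialized buffer, not a new one.
        assert lazy_t.torch_tensor() is first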