mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] ColoTensor pin_memory (#840)
parent
9f6f656952
commit
4575a3298b
|
@ -20,21 +20,27 @@ class ColoTensor(object):
|
|||
dtype=None,
|
||||
requires_grad=False,
|
||||
pin_memory=False,
|
||||
device=None,
|
||||
torch_tensor=torch.empty(0),
|
||||
):
|
||||
self._size = size
|
||||
self._dtype = dtype
|
||||
self._requires_grad = requires_grad
|
||||
self._pin_memory = pin_memory
|
||||
self._device = device
|
||||
self._torch_tensor = torch_tensor
|
||||
|
||||
def numel(self):
|
||||
return sum(self._size)
|
||||
|
||||
@staticmethod
|
||||
def init_from_torch_tensor(tensor: torch.Tensor):
|
||||
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
|
||||
colo_t = ColoTensor(*tensor.size(),
|
||||
dtype=tensor.dtype,
|
||||
requires_grad=tensor.requires_grad,
|
||||
pin_memory=tensor.pin_memory,
|
||||
torch_tensor=tensor)
|
||||
pin_memory=tensor.is_pinned(),
|
||||
device=tensor.device,
|
||||
torch_tensor=tensor if save_payload else torch.empty(0))
|
||||
return colo_t
|
||||
|
||||
def del_torch_tensor(self) -> None:
|
||||
|
@ -42,12 +48,12 @@ class ColoTensor(object):
|
|||
self._torch_tensor = torch.empty(self._size)
|
||||
|
||||
def torch_tensor(self) -> torch.Tensor:
|
||||
if self._torch_tensor == None or self._torch_tensor.numel() == 0:
|
||||
print(self._size, type(self._size))
|
||||
if self._torch_tensor.numel() == 0:
|
||||
self._torch_tensor = torch.empty(*self._size,
|
||||
dtype=self._dtype,
|
||||
pin_memory=self._pin_memory,
|
||||
requires_grad=self._requires_grad,
|
||||
pin_memory=self._pin_memory)
|
||||
device=self._device)
|
||||
return self._torch_tensor
|
||||
|
||||
@classmethod
|
||||
|
@ -67,7 +73,5 @@ class ColoTensor(object):
|
|||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
kwargs = {
|
||||
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
|
||||
}
|
||||
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
|
||||
return func(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue