[hotfix] ColoTensor pin_memory (#840)

pull/842/head^2
Jiarui Fang 2022-04-22 17:07:46 +08:00 committed by GitHub
parent 9f6f656952
commit 4575a3298b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 13 additions and 9 deletions

View File

@ -20,21 +20,27 @@ class ColoTensor(object):
dtype=None,
requires_grad=False,
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
):
self._size = size
self._dtype = dtype
self._requires_grad = requires_grad
self._pin_memory = pin_memory
self._device = device
self._torch_tensor = torch_tensor
def numel(self):
    """Return the total number of elements described by ``self._size``.

    A tensor's element count is the *product* of its dimension sizes,
    not their sum (a (2, 3) tensor holds 6 elements, not 5).  An empty
    size tuple denotes a scalar, for which the product correctly
    evaluates to 1.
    """
    # NOTE(review): the original returned sum(self._size), which is wrong
    # for every tensor with more than one dimension and returns 0 for a
    # scalar; replaced with the dimension product.
    count = 1
    for dim in self._size:
        count *= dim
    return count
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
    """Build a ColoTensor mirroring an existing ``torch.Tensor``.

    The rendered diff left both the pre- and post-commit lines in this
    span (two signatures, duplicated keyword arguments), which is not
    valid Python; this is the reconstructed committed version.

    Args:
        tensor: source tensor whose size, dtype, requires_grad flag,
            pinned-memory state and device are copied as metadata.
        save_payload: when True, keep ``tensor`` itself as the payload;
            when False, record only metadata and store an empty
            placeholder tensor instead.

    Returns:
        The newly constructed ColoTensor.
    """
    colo_t = ColoTensor(*tensor.size(),
                        dtype=tensor.dtype,
                        requires_grad=tensor.requires_grad,
                        # is_pinned() actually queries pinning state;
                        # ``tensor.pin_memory`` is a bound method object
                        # and would always be truthy.
                        pin_memory=tensor.is_pinned(),
                        device=tensor.device,
                        torch_tensor=tensor if save_payload else torch.empty(0))
    return colo_t
def del_torch_tensor(self) -> None:
@ -42,12 +48,12 @@ class ColoTensor(object):
self._torch_tensor = torch.empty(self._size)
def torch_tensor(self) -> torch.Tensor:
    """Return the payload tensor, lazily materialising it when absent.

    The rendered diff interleaved the removed lines (an ``== None``
    comparison — which should be ``is None`` in any case — plus a debug
    ``print``) with the new body, leaving a duplicate ``pin_memory``
    keyword; this is the reconstructed committed version.

    If the stored payload is the empty placeholder (``numel() == 0``),
    allocate a fresh uninitialised tensor from the recorded metadata
    (size, dtype, pin_memory, requires_grad, device), cache it, and
    return it; otherwise return the cached payload unchanged.
    """
    if self._torch_tensor.numel() == 0:
        self._torch_tensor = torch.empty(*self._size,
                                         dtype=self._dtype,
                                         pin_memory=self._pin_memory,
                                         requires_grad=self._requires_grad,
                                         device=self._device)
    return self._torch_tensor
@classmethod
@ -67,7 +73,5 @@ class ColoTensor(object):
if kwargs is None:
kwargs = {}
kwargs = {
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return func(*args, **kwargs)